functorch's Introduction

functorch

Why functorch? | Install guide | Transformations | Documentation | Future Plans

This library is currently under heavy development - if you have suggestions on the API or use cases you'd like covered, please open a GitHub issue or reach out. We'd love to hear about how you're using the library.

functorch is JAX-like composable function transforms for PyTorch.

It aims to provide composable vmap and grad transforms that work with PyTorch modules and PyTorch autograd with good eager-mode performance.

In addition, there is experimental functionality to trace through these transformations using FX in order to capture the results of these transforms ahead of time. This would allow us to compile the results of vmap or grad to improve performance.

Why composable function transforms?

There are a number of use cases that are tricky to do in PyTorch today:

  • computing per-sample-gradients (or other per-sample quantities)
  • running ensembles of models on a single machine
  • efficiently batching together tasks in the inner-loop of MAML
  • efficiently computing Jacobians and Hessians
  • efficiently computing batched Jacobians and Hessians

Composing vmap, grad, vjp, and jvp transforms allows us to express the above without designing a separate subsystem for each. This idea of composable function transforms comes from the JAX framework.

Install

There are two ways to install functorch:

  1. functorch from source
  2. functorch beta (compatible with recent PyTorch releases)

We recommend trying out the functorch beta first.

Installing functorch from source

Using Colab

Follow the instructions in this Colab notebook

Locally

As of 9/21/2022, functorch comes installed alongside a nightly PyTorch binary. Please install a Preview (nightly) PyTorch binary; see https://pytorch.org/ for instructions.

Once you've done that, run a quick sanity check in Python:

import torch
from functorch import vmap
x = torch.randn(3)
y = vmap(torch.sin)(x)
assert torch.allclose(y, x.sin())

functorch development setup

As of 9/21/2022, functorch comes installed alongside PyTorch and is in the PyTorch source tree. Please install PyTorch from source; you will then be able to import functorch.

Try to run some tests to make sure all is OK:

pytest test/test_vmap.py -v
pytest test/test_eager_transforms.py -v

AOTAutograd has some additional optional requirements. You can install them via:

pip install networkx

To run functorch tests, please install our test dependencies (expecttest, pyyaml).

Installing functorch beta (compatible with recent PyTorch releases)

Using Colab

Follow the instructions here

pip

Prerequisite: Install PyTorch

pip install functorch

Finally, run a quick sanity check in Python:

import torch
from functorch import vmap
x = torch.randn(3)
y = vmap(torch.sin)(x)
assert torch.allclose(y, x.sin())

What are the transforms?

Right now, we support the following transforms:

  • grad, vjp, jvp,
  • jacrev, jacfwd, hessian
  • vmap

Furthermore, we have some utilities for working with PyTorch modules.

  • make_functional(model)
  • make_functional_with_buffers(model)

vmap

Note: vmap imposes restrictions on the code that it can be used on. For more details, please read its docstring.

vmap(func)(*inputs) is a transform that adds a dimension to all Tensor operations in func. vmap(func) returns a new function that maps func over some dimension (default: 0) of each Tensor in inputs.

vmap is useful for hiding batch dimensions: one can write a function func that runs on examples and then lift it to a function that can take batches of examples with vmap(func), leading to a simpler modeling experience:

from functorch import vmap
batch_size, feature_size = 3, 5
weights = torch.randn(feature_size, requires_grad=True)

def model(feature_vec):
    # Very simple linear model with activation
    assert feature_vec.dim() == 1
    return feature_vec.dot(weights).relu()

examples = torch.randn(batch_size, feature_size)
result = vmap(model)(examples)

grad

grad(func)(*inputs) assumes func returns a single-element Tensor. It computes the gradient of the output of func with respect to inputs[0].

from functorch import grad
x = torch.randn([])
cos_x = grad(lambda x: torch.sin(x))(x)
assert torch.allclose(cos_x, x.cos())

# Second-order gradients
neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x)
assert torch.allclose(neg_sin_x, -x.sin())

When composed with vmap, grad can be used to compute per-sample-gradients:

from functorch import vmap
batch_size, feature_size = 3, 5

def model(weights, feature_vec):
    # Very simple linear model with activation
    assert feature_vec.dim() == 1
    return feature_vec.dot(weights).relu()

def compute_loss(weights, example, target):
    y = model(weights, example)
    return ((y - target) ** 2).mean()  # MSELoss

weights = torch.randn(feature_size, requires_grad=True)
examples = torch.randn(batch_size, feature_size)
targets = torch.randn(batch_size)
inputs = (weights, examples, targets)
grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(*inputs)

vjp

The vjp transform applies func to inputs and returns a new function that computes vector-Jacobian products (vjps) given some cotangent Tensors.

from functorch import vjp
outputs, vjp_fn = vjp(func, inputs); vjps = vjp_fn(*cotangents)
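
For a concrete sanity check (a minimal sketch, not from the original README): the Jacobian of torch.sin is diagonal, so a cotangent of all ones recovers cos(x):

import torch
from functorch import vjp

x = torch.randn(5)
output, vjp_fn = vjp(torch.sin, x)
# vjp_fn maps a cotangent v to v @ J, where J is the Jacobian of sin at x.
(vjps,) = vjp_fn(torch.ones(5))
assert torch.allclose(vjps, x.cos())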

jvp

The jvp transform computes Jacobian-vector products and is also known as "forward-mode AD". Unlike most other transforms, it is not a higher-order function: it returns the outputs of func(inputs) as well as the jvps.

from functorch import jvp
x = torch.randn(5)
y = torch.randn(5)
f = lambda x, y: (x * y)
_, output = jvp(f, (x, y), (torch.ones(5), torch.ones(5)))
assert torch.allclose(output, x + y)

jacrev, jacfwd, and hessian

The jacrev transform returns a new function that computes the Jacobian of func using reverse-mode AD. In the example below, jacrev(torch.sin) takes in x and returns the Jacobian of torch.sin with respect to x:

from functorch import jacrev
x = torch.randn(5)
jacobian = jacrev(torch.sin)(x)
expected = torch.diag(torch.cos(x))
assert torch.allclose(jacobian, expected)

jacrev can be composed with vmap to produce batched Jacobians:

x = torch.randn(64, 5)
jacobian = vmap(jacrev(torch.sin))(x)
assert jacobian.shape == (64, 5, 5)

jacfwd is a drop-in replacement for jacrev that computes Jacobians using forward-mode AD:

from functorch import jacfwd
x = torch.randn(5)
jacobian = jacfwd(torch.sin)(x)
expected = torch.diag(torch.cos(x))
assert torch.allclose(jacobian, expected)

Composing jacrev with itself or jacfwd can produce hessians:

def f(x):
  return x.sin().sum()

x = torch.randn(5)
hessian0 = jacrev(jacrev(f))(x)
hessian1 = jacfwd(jacrev(f))(x)

hessian is a convenience function that combines jacfwd and jacrev:

from functorch import hessian

def f(x):
  return x.sin().sum()

x = torch.randn(5)
hess = hessian(f)(x)

Tracing through the transformations

We can also trace through these transformations in order to capture the results as new code using make_fx. There is also experimental integration with the NNC compiler (only works on CPU for now!).

from functorch import make_fx, grad
def f(x):
    return torch.sin(x).sum()
x = torch.randn(100)
grad_f = make_fx(grad(f))(x)
print(grad_f.code)

def forward(self, x_1):
    sin = torch.ops.aten.sin(x_1)
    sum_1 = torch.ops.aten.sum(sin, None);  sin = None
    cos = torch.ops.aten.cos(x_1);  x_1 = None
    _tensor_constant0 = self._tensor_constant0
    mul = torch.ops.aten.mul(_tensor_constant0, cos);  _tensor_constant0 = cos = None
    return mul

Working with NN modules: make_functional and friends

Sometimes you may want to perform a transform with respect to the parameters and/or buffers of an nn.Module. This can happen for example in:

  • model ensembling, where all of your weights and buffers have an additional dimension
  • per-sample-gradient computation where you want to compute per-sample-grads of the loss with respect to the model parameters

Our solution to this right now is an API that, given an nn.Module, creates a stateless version of it that can be called like a function.

  • make_functional(model) returns a functional version of model and the model.parameters()
  • make_functional_with_buffers(model) returns a functional version of model and the model.parameters() and model.buffers().

Here's an example where we compute per-sample-gradients using an nn.Linear layer:

import torch
from functorch import make_functional, vmap, grad

model = torch.nn.Linear(3, 3)
data = torch.randn(64, 3)
targets = torch.randn(64, 3)

func_model, params = make_functional(model)

def compute_loss(params, data, targets):
    preds = func_model(params, data)
    return torch.mean((preds - targets) ** 2)

per_sample_grads = vmap(grad(compute_loss), (None, 0, 0))(params, data, targets)

If you're making an ensemble of models, you may find combine_state_for_ensemble useful.
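
As a rough sketch (check the documentation for the exact signature), combine_state_for_ensemble stacks the parameters and buffers of several identical models so that a single vmap call evaluates the whole ensemble on the same input:

import torch
from functorch import combine_state_for_ensemble, vmap

models = [torch.nn.Linear(3, 3) for _ in range(5)]
fmodel, params, buffers = combine_state_for_ensemble(models)

x = torch.randn(64, 3)
# Map over the ensemble dimension of params/buffers; share the same input x.
out = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, x)
assert out.shape == (5, 64, 3)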

Documentation

For more documentation, see our docs website.

Debugging

  • torch._C._functorch.dump_tensor: dumps the dispatch keys on the stack
  • torch._C._functorch._set_vmap_fallback_warning_enabled(False): use this if the vmap fallback warning spam bothers you
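
For example, to turn off the fallback warning (it flags genuinely slow fallback paths, so use it sparingly):

import torch
torch._C._functorch._set_vmap_fallback_warning_enabled(False)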

Future Plans

In the end state, we'd like to upstream this into PyTorch once we iron out the design details. To figure out the details, we need your help -- please send us your use cases by starting a conversation in the issue tracker or trying our project out.

License

Functorch has a BSD-style license, as found in the LICENSE file.

Citing functorch

If you use functorch in your publication, please cite it by using the following BibTeX entry.

@Misc{functorch2021,
  author =       {Horace He and Richard Zou},
  title =        {functorch: JAX-like composable function transforms for PyTorch},
  howpublished = {\url{https://github.com/pytorch/functorch}},
  year =         {2021}
}

functorch's Issues

Print vmap warnings by default

  1. torch._C._debug_only_display_vmap_fallback_warnings -> Use a functorch API instead of this
  2. Turn on warnings by default so that it is clear we are not promising good perf

grad doesn't work with _VF.frobenius_norm

from functorch import jacrev, vmap, grad

def f(x):
    return torch.norm(x, dim=1).sum()

print(grad(f)(torch.randn(3, 3)))
Traceback (most recent call last):
  File "t.py", line 39, in <module>
    print(grad(f)(torch.randn(3, 3)))
  File "/home/chilli/fb/functorch/functorch/_src/eager_transforms.py", line 149, in wrapper
    results = grad_and_value(f, argnums, has_aux=has_aux)(*args)
  File "/home/chilli/fb/functorch/functorch/_src/eager_transforms.py", line 107, in wrapper
    output = f(*args)
  File "t.py", line 37, in f
    return torch.norm(x, dim=1).sum()
  File "/home/chilli/fb/pytorch/torch/functional.py", line 1441, in norm
    return _VF.frobenius_norm(input, _dim, keepdim=keepdim)
NotImplementedError: Cannot access storage of TensorWrapper

Haven't dug into this issue - it might have the same root cause as others, e.g. torch.tensor.

grad doesn't run when under `torch.no_grad()`

from functorch import grad, vmap, pythonkey_trace, wrap_key
import torch
import torch.fx as fx

def f(x):
    return torch.sin(x)
with torch.no_grad():
    print(grad(f)(torch.randn(())))
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

Not totally sure what the semantics should be... but I kinda think we should be ignoring torch.no_grad().

CUBLAS_STATUS_ALLOC_FAILED when jacrev of jacrev of matmul

import torch
from functorch import jacrev

device = 'cuda'

N = 5
M = 3
W = torch.randn(N, M, device=device)

def f(x):
    return W @ x

x = torch.randn(M)
result = jacrev(jacrev(f))(x)
expected = torch.zeros(N, M, M, device=device)
assert torch.allclose(result, expected)

Running list of batching rules that needs to be implemented

For any examples we're running, if we see fallback warnings, we can add it to the list here so that we have a list of batching rules we can chip away at. We can put our name next to one if we're planning on adding it.

Batching Rules needed for Omniglot:

  • aten::mkldnn_convolution
  • aten::native_batch_norm
  • aten::nll_loss_forward
  • aten::nll_loss_backward
  • aten::_log_softmax_backward_data
  • aten::max_pool2d_with_indices_backward
  • aten::threshold_backward
  • aten::native_batch_norm_backward
  • aten::mkldnn_convolution_backward
  • aten::conv2d
  • aten::batch_norm
  • aten::linear
  • aten::nll_loss_nd
  • aten::argmax
  • aten::eq.Tensor @zou3519

Parallel Train:

  • aten::nll_loss_forward
  • aten::nll_loss_backward
  • aten::_log_softmax_backward_data
  • aten::threshold_backward @zou3519

DP Cifar10:

  • aten::mkldnn_convolution
  • aten::native_group_norm
  • aten::relu_ @zou3519
  • aten::thnn_conv2d_forward
  • aten::add_.Tensor @zou3519
  • aten::nll_loss_forward
  • aten::nll_loss_backward
  • aten::_log_softmax_backward_data
  • aten::threshold_backward @zou3519
  • aten::reciprocal_ @zou3519
  • aten::clamp_min: Warning: make sure this compiles using clang too
  • aten::thnn_conv2d_backward.output_mask
  • aten::max_pool2d_with_indices_backward
  • aten::mkldnn_convolution_backward
  • aten::cudnn_convolution
  • aten::cudnn_convolution_backward
  • aten::native_group_norm

From #26:

  • aten::rsub.Scalar

  • aten::diag

  • aten::where.Scalar

  • aten::allclose

  • advanced indexing (index, index_put_)

Top 100 torch.foo:

  • t 6837449
  • tensor 585786
  • mode 462182
  • cat 394818
  • max 368038
  • zeros 329495
  • load 327756
  • no_grad 294694
  • save 265130
  • from_numpy 243063
  • manual_seed 165044
  • ones 153696
  • randn 150796
  • stack 133358
  • sum 130772
  • arange 98087
  • rand 94715
  • mean 88546
  • exp 73883
  • zeros_like 72831
  • min 72248
  • sigmoid 66798
  • log 62135
  • matmul 47811
  • clamp 45304
  • sqrt 44911
  • abs 43535
  • tanh 42793
  • empty 40311
  • argmax 38435
  • bmm 33984
  • pow 33571
  • norm 31125 (deprecated?)
  • mm 30995
  • is_tensor 29546
  • ones_like 29512
  • nonzero 28681 (dynamic)
  • full 28373
  • unsqueeze 27911
  • where 26585
  • randperm 26450 (random)
  • eye 24342
  • mul 23236
  • topk 22537
  • as_tensor 21967
  • sort 21412
  • squeeze 20863
  • randint 20771 (random)
  • linspace 20041
  • add 19201
  • transpose 18663
  • split 18325
  • gather 17904
  • set_grad_enabled 16013
  • sin 15669
  • cos 15562
  • div 15513
  • index_select 14866
  • multinomial 14331 (random)
  • flatten 14267
  • isnan 14170
  • randn_like 13096 (random)
  • eq 12680
  • einsum 12480
  • round 12367
  • floor 11628
  • allclose 11000
  • reshape 10605
  • diag 10167
  • chunk 9581
  • std 9379
  • set_default_tensor_type 9281
  • triu 8559
  • meshgrid 8292
  • set_num_threads 8126
  • unique 7964 (dynamic)
  • full_like 7780
  • tril 7538
  • dot 7275
  • sign 6943
  • equal 6916
  • normal 6750 (random)
  • cumsum 6556
  • dist 6058
  • isfinite 6030
  • gt 5935
  • set_printoptions 5888
  • range 5491
  • empty_like 5351
  • flip 5342
  • masked_select 5341 (sometimes dynamic)
  • bernoulli 5262 (random)
  • atan 5253
  • var 5247
  • prod 5200
  • erf 5088
  • inverse 5072
  • addmm 4854
  • logsumexp 4582

CompositeImplicitAutograd ops that call *_like or new_* operators fail under certain transforms

import torch
import functorch
from functorch import vmap, grad

N = 3
C = 5

device = 'cpu'

def foo(x):
    result = x.contiguous()
    return result.sum()

x = torch.randn(C, N, device=device).t()
result = vmap(grad(foo))(x)

fails with

RuntimeError: vmap: aten::copy_(self, *extra_args) is not possible because there exists a Tensor `other` in
extra_args that has more elements than `self`. This happened due to `other` being vmapped over but `self` not
being vmapped over at level 2. Please try to use out-of-place operators instead of aten::copy_. If said
operator is being called inside the PyTorch framework, please file a bug report instead.

Figure out API for working with an ensemble of modules

Problem: How does one initialize an ensemble of modules?

One possible solution is to offer an API that returns parameters with an additional "ensemble" dimension. For example, if we were trying to ensemble models that contained a single nn.Linear layer, then we'd return a weight and bias each with an extra ensemble dimension.

That leads to something like the following API:

state_dict = functional_init_ensemble(nn.Linear, 3, 3, ensemble_size=5, device='cpu')

This returns a state_dict that has two elements:

  • weight with shape (5, 3, 3)
  • bias with shape (5, 3)

There are some problems with returning a state dict:

  • need some way of separating out buffers and parameters

Another way to do this is to have all nn.Modules take in an additional 'ensemble_size' dimension and straight up just return nn.Modules...

module = nn.Linear(3, 3, ensemble_size=1)

Some requirements:

  • this should work on user-defined modules as well
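
For reference, one way to get an "ensemble" dimension today (a rough sketch using this repo's make_functional, not a proposed new API) is to instantiate the modules normally and stack their parameters along a new leading dimension:

import torch
from functorch import make_functional, vmap

ensemble_size = 5
models = [torch.nn.Linear(3, 3) for _ in range(ensemble_size)]

# make_functional returns (func, params); the func is the same for identical
# architectures, so keep one func and stack the parameters along a new
# leading "ensemble" dimension.
funcs_and_params = [make_functional(m) for m in models]
func = funcs_and_params[0][0]
stacked_params = tuple(
    torch.stack(ps) for ps in zip(*(params for _, params in funcs_and_params))
)

x = torch.randn(3)
out = vmap(func, in_dims=(0, None))(stacked_params, x)  # out.shape == (5, 3)

This still leaves the buffers-vs-parameters question raised above open.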

Decompose CompositeImplicitAutograd ops at the FuncTorchBatched key

Background

@ezyang suggested trying this to minimize the number of operators we have to override. More concretely, instead of registering all 2000 operators to FuncTorchBatched, we would only have to register (insert number here) operators that are not composite w.r.t. autograd.

To be concrete, the suggestion was to add FuncTorchBatched to https://github.com/pytorch/pytorch/blob/8dd0570b34c7c378ae9729c21267546cba07fdc9/c10/core/DispatchKeySet.cpp#L28-L32

The experiment

I added FuncTorchBatched to https://github.com/pytorch/pytorch/blob/8dd0570b34c7c378ae9729c21267546cba07fdc9/c10/core/DispatchKeySet.cpp#L28-L32, recompiled PyTorch and functorch, and then ran the test suite. This leads to a fun number of failures (see here) that have the same root cause!

The problem is that some CompositeImplicitAutograd ops decompose to in-place operations that are not compatible with vmap (note here).

Can we solve these problems by just registering an override for the vmap key for those operations?

  • that would solve the vmap(blah) problem but I'm not sure because a vmap(grad(blah)) is always going to decompose blah since it runs through the grad transform.

test/test_eager_transforms.py::TestVmapOfGradCPU::test_log_softmax_cpu is broken

RuntimeError: backward() called inside torch.vmap. This is not supported, please call backward() outside torch.vmap or instead use torch.autograd.grad inside torch.vmap

This is what that test looks like:

def test_log_softmax(self, device):
    x = torch.randn(3, 5)
    v = torch.randn(5)

    def foo(x, v):
        _, vjp_fn = vjp(partial(torch.log_softmax, dim=-1), x)
        return vjp_fn(v)[0]

    result = vmap(foo, (0, None))(x, v)

    v = v.expand_as(x)
    x.requires_grad_()
    output = torch.log_softmax(x, dim=-1)
    output.backward(v)
    self.assertEqual(result, x.grad)

Set up a CI and figure out how functorch depends on PyTorch master

The current scheme we're working with is "functorch should always work with the existing PyTorch viable/strict branch".

Motivations:

  • If a change to PyTorch core requires a change to functorch, it would be nice to catch it sooner than later
  • As the functorch test suite balloons, it is nice to have tests for committed code.

TODO:

  • Initial functorch CI, runs all tests except test_pythonkey.py #53
  • Add a build for LLVM PyTorch that runs test_pythonkey.py

Improve make_functional*

Things that we can and should do now:

  • make_functional* should not destroy the original model
  • Let func be the function returned by make_functional_with_buffers. func should accept arguments as func(params, buffers, *args, **kwargs) instead of func(params, buffers, args) (what it currently accepts).
  • We can make func into a special FunctionalModule class so it is registered as a subclass of nn.Module. This makes it so that func.eval() and func.train() work; furthermore, func can be registered as a submodule of an owning Module. func is still callable like a function and has no state.
  • We should probably provide a helper function to "stack" weights and buffers to prepare them for vmap. Probably just accept a list of the "same" module and we will stack them together.

Things we should consider but are more tricky:

  • Combine the weights and buffers return value. The difficulty around this is the interaction with functorch.grad

jacrev(jacrev(f)) fails for matmul

from functorch import jacrev
N = 5
M = 3
W = torch.randn(N, M)
def f(x):
    return W @ x
inps = (torch.randn(M),)
print(jacrev(jacrev(f))(*inps))
Traceback (most recent call last):
  File "python_key.py", line 70, in <module>
    print(jacrev(jacrev(f))(*inps))
  File "/opt/anaconda/lib/python3.7/site-packages/functorch-0.0.1a0+e27e16f-py3.7-linux-x86_64.egg/functorch/_src/eager_transforms.py", line 87, in wrapper_fn
    result, = vmap(vjp_fn)(basis)
  File "/opt/anaconda/lib/python3.7/site-packages/functorch-0.0.1a0+e27e16f-py3.7-linux-x86_64.egg/functorch/_src/vmap.py", line 258, in wrapped
    batched_outputs = func(*batched_inputs)
  File "/opt/anaconda/lib/python3.7/site-packages/functorch-0.0.1a0+e27e16f-py3.7-linux-x86_64.egg/functorch/_src/eager_transforms.py", line 72, in wrapper
    retain_graph=retain_graph, create_graph=create_graph)
  File "/home/chilli/fb/pytorch/torch/autograd/__init__.py", line 228, in grad
    inputs, allow_unused, accumulate_grad=False)
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

Please keep up the great work!!

This is not a bug report or feature request, but more a shout of admiration. I am a big-time PyTorch fan, but have been looking at JAX because they support vmap and other awesome functional tools. If PyTorch succeeds in reproducing these features, I think it takes one major reason away from switching to JAX.

Please, please keep up the good work and make this happen! Thanks for listening to your community and taking the initiative. You guys are the best!

Should matmul have a decomposed batch rule or an actual one?

Right now the batch rule for matmul is decomposed (https://github.com/zou3519/functorch/blob/53144b92d33d6d796359c97764ee68743f5463bf/functorch/csrc/BatchingRegistrations.cpp#L1254).

My worry is that it might be possible for us to transform some code into inefficient code. For example, if B0 and B1 are vmap dimensions and we are matrix-multiplying tensor of size [B0, 5, 5], [B1, 5, 5], we don't want to multiply tensors of size [B0, 1, 5, 5] and [1, B1, 5, 5]. If that happens, then internally, matmul will expand the tensors to [B0, B1, 5, 5] and materialize the full memory, which can be quite slow. (The ideal way to multiply these tensors is to reshape them into [B0 * 5, 5] and [5, B1 * 5], and then multiply them together).

This issue is probably just a code reading exercise to see if it's possible for the above to happen in the decomposed matmul code. I was in the middle of writing a non-decomposed matmul here: https://gist.github.com/zou3519/ddd4b2d4aacc98bf20d114f26b27b082
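
For concreteness, here is one way (a sketch, not taken from the issue) that the two-vmap-dimension situation above can arise, by nesting vmap over the two batch dimensions:

import torch
from functorch import vmap

B0, B1 = 8, 16
A = torch.randn(B0, 5, 5)
B = torch.randn(B1, 5, 5)

# Map over A's batch dim in the outer vmap and B's batch dim in the inner one,
# so matmul sees one vmap dimension coming from each argument.
out = vmap(vmap(torch.matmul, in_dims=(None, 0)), in_dims=(0, None))(A, B)
assert out.shape == (B0, B1, 5, 5)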

Ops used in torch.distributions to lower

HalfCauchy:

{'div', 'pow', 'sum', 'log', 'lt', 'exp', 'add', 'mul', 'index_put_', 'neg', 'unsqueeze', 'expand', 'sub', '_local_scalar_dense'}

LKJCholesky:

{'ge', 'softplus_backward', 'sum', 'tanh_backward', 'softplus', 'sub', 'tanh', 'add', 'mul', 'neg', '_s_where', 'clamp', 'logical_and', 'le', 'expand', 'index_put_'}

Uniform:

{'div', 'sigmoid', 'sum', 'softplus_backward', 'softplus', 'copy_', 'add', 'mul', 'neg', 'clamp', 'le', 'gt', 'sub'}

Bernoulli:

{'neg', 'binary_cross_entropy_with_logits', 'sub', 'sum'}

Beta:

{'sigmoid_backward', 'sigmoid', 'unbind', 'ge', 'sum', 'softplus', 'mul', 'stack', 'sub', 'log', 'softplus_backward', 'add', 'rsub', 'neg', 'logical_and', 'div', 'clamp', 'le', '_s_where'}

Dirichlet:

{'sigmoid_backward', 'sigmoid', 'log_sigmoid_backward', 'expand', 'ge', 'slice', 'getitem', 'sum', 'mul', 'constant_pad_nd', 'log', 'sub', 'add', 'rsub', 'neg', 'logical_and', 'log_sigmoid_forward', 'div', 'cumprod', 'copy_', 'clamp', 'le', '_s_where'}

HalfNormal:

index_select

Figure out how to transform over optimizers

One way to transform over training loops (e.g. to do model ensembling or the inner step of a MAML) is to use a function that represents the optimizer step instead of an actual PyTorch optimizer. Right now I think we have the following requirements

  • There should be a function version of each optimizer (e.g. F.sgd)
  • The function should have an option to not mutate (e.g. F.sgd(..., inplace=False))
  • The function should be differentiable

PyTorch already has some here (in Prototype stage): https://github.com/pytorch/pytorch/blob/master/torch/optim/_functional.py, so we should check if these fit the requirements, and, if not, decide if we should influence the design
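
For illustration, a minimal sketch of an optimizer step that satisfies the three requirements above (an assumption for discussion, not torch.optim's actual functional API):

import torch
from functorch import grad

def sgd_step(params, grads, lr=0.1):
    # Pure and non-mutating: returns new parameters instead of updating them
    # in place, so the update itself stays differentiable and composes with
    # vmap/grad.
    return [p - lr * g for p, g in zip(params, grads)]

def loss(params, x):
    w, = params
    return ((x @ w) ** 2).sum()

params = [torch.randn(3, 1)]
x = torch.randn(4, 3)
new_params = sgd_step(params, grad(loss)(params, x), lr=0.01)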

Multiple Inner Loops

Great project --

What would be the most straightforward way to allow for multiple inner training loops?

Specifically with regards to the MAML example, how could I allow for a user defined number of inner loops?

def get_loss_for_task(x1, y1, x2, y2):
    def inner_loss(params, x1, y1):
        f = net(params, (x1,))
        loss = mse_loss(f, y1)
        return loss

    grads = grad(inner_loss)(params, x1, y1)
    new_params = [(params[i] - alpha * grads[i]) for i in range(len(params))]

    v_f = net(new_params, (x2,))
    return mse_loss(v_f, y2)

As a note, after changing losses.append(loss2) to losses.append(loss2.item()) I was able to plot the results
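
One possible way to generalize this (a sketch; net, mse_loss, params, alpha, and grad are assumed to be defined as in the MAML example) is to loop the inner update for a user-defined number of steps:

def get_loss_for_task(x1, y1, x2, y2, num_inner_steps=5):
    def inner_loss(params, x1, y1):
        return mse_loss(net(params, (x1,)), y1)

    new_params = params
    for _ in range(num_inner_steps):
        grads = grad(inner_loss)(new_params, x1, y1)
        new_params = [p - alpha * g for p, g in zip(new_params, grads)]

    return mse_loss(net(new_params, (x2,)), y2)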

Handle namedtuples and PyTorch's special return types correctly

Right now, the vmap and grad transforms will ignore namedtuples and pretend they are tuples. This leads to the names getting stripped. We should change pytrees to support named tuples and PyTorch's special return types. I don't know if PyTorch's special return types are actually named tuples though.

Handle cases where the gradients are 0 for inputs

from functorch import grad
import torch

def f(x):
    return (x[0]**2.0).sum()
inps = (torch.randn(3), torch.randn(3))
fx_graph = grad(f)(inps)

Error:

Traceback (most recent call last):
  File "t.py", line 7, in <module>
    fx_graph = grad(f)(inps)
  File "/home/chilli/fb/functorch/functorch/_src/eager_transforms.py", line 155, in wrapper
    results = grad_and_value(f, argnums, has_aux=has_aux)(*args)
  File "/home/chilli/fb/functorch/functorch/_src/eager_transforms.py", line 135, in wrapper
    output, flat_diff_args, create_graph=True)
  File "/home/chilli/fb/pytorch/torch/autograd/__init__.py", line 228, in grad
    inputs, allow_unused, accumulate_grad=False)
RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.

PyTorch Lightning Integration

Hey there,

Awesome work with this library !

I am part of the PyTorch Lightning team, and we could integrate FuncTorch pretty easily before it gets upstreamed to PyTorch. In Lightning, it will be added as a new TrainingType Plugin and it should make benchmarking simpler for you, as many models are already implemented.

It could look like Trainer(accelerator='pmap').

If you are interested, please join Lightning Slack and PM me 🤗

Best,
T.C

Transforms don't work with new_ones

import torch
from functorch import vmap

def f(x):
    return x.new_ones(x.shape)

print(vmap(f)(torch.randn(3)))

>>> RuntimeError: DispatchKey FuncTorchBatched doesn't correspond to a device

Prototype vmap over data-dependent control flow

There needs to be something to:

  1. capture control flow
  2. represent control flow in a form that is transformable
  3. actually transform the control flow (e.g. a batching rule)

We don't have to worry too much about (1) for now. A way to prototype (2) would be to have something like control flow operators. These can either be python-based, or go through the PyTorch C++ dispatcher (!!).
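
A very rough, purely illustrative sketch of ideas (2) and (3): a python-level cond operator plus one possible batching rule for it (none of this is functorch API):

import torch

def cond(pred, true_fn, false_fn, operand):
    # Unbatched semantics: an ordinary data-dependent branch.
    return true_fn(operand) if pred else false_fn(operand)

def cond_batching_rule(preds, true_fn, false_fn, operands):
    # One possible batching rule: run both branches on the whole batch and
    # select per-example results. Assumes preds is (B,) and operands is (B, D).
    true_out = true_fn(operands)
    false_out = false_fn(operands)
    return torch.where(preds.unsqueeze(-1), true_out, false_out)

operands = torch.randn(4, 3)
preds = operands.sum(dim=1) > 0
out = cond_batching_rule(preds, lambda x: x ** 2, torch.relu, operands)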

indexing with a `True` tensor fails under grad

import torch
from functorch import grad

def f(value):
    log_prob = torch.ones(())
    val = (torch.zeros(()) > 0)
    log_prob[val] = 0
    return value

grad(f)(torch.randn(()))

>>> Traceback (most recent call last):
  File "t.py", line 13, in <module>
    grad(f)(torch.randn(()))
  File "/home/chilli/fb/functorch/functorch/_src/eager_transforms.py", line 178, in wrapper
    results = grad_and_value(f, argnums, has_aux=has_aux)(*args)
  File "/home/chilli/fb/functorch/functorch/_src/eager_transforms.py", line 142, in wrapper
    output = f(*args)
  File "t.py", line 10, in f
    log_prob[val] = 0
NotImplementedError: Cannot access storage of TensorWrapper

Can't call `torch.tensor` within grad

from functorch import grad, vmap
import torch

def f(x):
    t = torch.tensor(0)
    return t + x
inps = (torch.randn([]),)
print(grad(f)(*inps))
Traceback (most recent call last):
  File "t.py", line 8, in <module>
    print(grad(f)(*inps))
  File "/home/chilli/fb/functorch/functorch/_src/eager_transforms.py", line 152, in wrapper
    results = grad_and_value(f, argnums, has_aux=has_aux)(*args)
  File "/home/chilli/fb/functorch/functorch/_src/eager_transforms.py", line 110, in wrapper
    output = f(*args)
  File "t.py", line 5, in f
    t = torch.tensor(0)
RuntimeError: Cannot access data pointer of Tensor that doesn't have storage

vjp is silently incorrect when copy_ is involved

Repro:

import torch
import functorch
from functorch import vjp

torch.manual_seed(0)

x = torch.randn(3, 5)
gy = torch.randn(3)

ggx = torch.arange(15, dtype=torch.float).view(3, 5)

def gp(x, gy):
    res = torch.zeros(3, 5)
    res.diagonal(2).copy_(gy)
    return res

gx, vjp_fn = vjp(gp, x, gy)
result = vjp_fn(ggx)

expected = torch.diag(ggx, 2)
print(result[1])
print(expected)
assert torch.allclose(result[1], expected)

Batching rule not implemented for aten::rsub.Scalar, aten::diag, aten::where.Scalar, aten::allclose.

Hello and thank you for the great project:
We are using it in a fully differentiable physics code atm and found some use cases that were not covered yet.
We are vmapping some functions that use torch.where and/or torch.diag e.g.:

def set_diagonal_to_inf(hamiltonian, value=10e9):
    """Args:
        hamiltonian: Matrix of hamiltonian (N_k, particle number*N_orbitals, particle number*N_orbitals)
        value: int/ float value the zeros are set to
    Returns:
        Hamiltonian with high eigenvalues for non existing particles/orbitals
    """
    diag = torch.sum(torch.abs(hamiltonian), axis=0)
    diag = torch.where(diag == 0, value, 0.)
    return torch.diag(diag)

and got the warning that they were not implemented yet:
functorch/_src/vmap.py:268: UserWarning: Batching rule not implemented for aten::diag falling back to slow (for loop and stack) implementation (Triggered internally at /tmp/pip-req-build-_exf5lof/functorch/csrc/BatchedFallback.cpp:89.)
batched_outputs = func(*batched_inputs)
UserWarning: Batching rule not implemented for aten::where.Scalar falling back to slow (for loop and stack) implementation (Triggered internally at /tmp/pip-req-build-_exf5lof/functorch/csrc/BatchedFallback.cpp:89.)

Another function was constructing a 2D array from an input vector, where each array element was constructed with a different formula. That function was performance relevant, but we were able to vectorize it without vmap; anyway, it led to the following warning:
torch/_tensor.py:544: UserWarning: Batching rule not implemented for aten::rsub.Scalar falling back to slow (for loop and stack) implementation (Triggered internally at /tmp/pip-req-build-_exf5lof/functorch/csrc/BatchedFallback.cpp:89.)
return _C._VariableFunctions.rsub(self, other)

At last we would like to solve a lot of eigenvalue problems with the symeig solver from xitorch https://github.com/xitorch/xitorch

def eig(ham):
    ham = (ham + ham.T) / 2
    ham = xitorch.LinearOperator.m(ham)
    return xilinalg.symeig(ham)[0]

parallel_solve = vmap(eig, 0)
test = torch.rand(7, 10, 10)
parallel_solve(test)

which led to the following error:
File "xitorch/_core/linop.py", line 100, in m
is_hermitian = torch.allclose(mat, mat.transpose(-2, -1))
RuntimeError: Batching rule not implemented for aten::allclose. We could not generate a fallback.

It would be great if you could find the time to add some of these, especially the last one.
PS: thank you for using the same vmap syntax as JAX; that saved a lot of time converting the code.

Question: `grad` and static compute graphs?

Hi all,

This looks like an awesome idea: it would be amazing to combine some of the functional transformation abilities of the smaller frameworks with the power of PyTorch!

Looking briefly at the code, it looks like your grad generates calls back into the autograd engine. Do you have any plans, along the lines of pytorch/pytorch#35215, to enable generating static compute graphs for the derivative of a function? Or is your plan to always use the dynamic autograd engine?
Thanks!

Batch rule plumbing codegen

Problem: Writing plumbing is repetitive, see link

We should have some way of auto-generating the plumbing and allowing a developer to insert some dispatching logic into the middle of the plumbing.

Proposal 1: Macro our way to victory

Every op gets a OP_PLUMBING_START and a OP_PLUMBING_END macro. Inside BatchRulesLoss.cpp, here's how we would write the plumbing for nll_loss_forward:

nll_loss_forward_PLUMBING_BEGIN
  if (!self_bdim && !target_bdim && !weight_bdim) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return at::nll_loss_forward(self_value, target_value, weight_value, reduction, ignore_index);
  }

  if (self_bdim && target_bdim && (!weight || !weight->defined()) && ignore_index < 0) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    auto results = nll_loss_forward_self_target_batch_rule(
        self_value, self_bdim, target_value, target_bdim, reduction);
    return std::make_tuple(
      makeBatched(std::get<0>(results), std::get<1>(results), cur_level),
      makeBatched(std::get<2>(results), std::get<3>(results), cur_level)
    );
  }
nll_loss_forward_PLUMBING_END


TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
  m.impl("nll_loss_forward", nll_loss_forward_plumbing);
}

Proposal 2: Have a "batch_rules.yaml" file

The .yaml file could handle the registration (e.g. m.impl("nll_loss_forward", nll_loss_forward_plumbing);) as well

 - name: nll_loss_forward(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index) -> (Tensor output, Tensor total_weight)
   dispatch: |
  if (!self_bdim && !target_bdim && !weight_bdim) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return at::nll_loss_forward(self_value, target_value, weight_value, reduction, ignore_index);
  }

  if (self_bdim && target_bdim && (!weight || !weight->defined()) && ignore_index < 0) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    auto results = nll_loss_forward_self_target_batch_rule(
        self_value, self_bdim, target_value, target_bdim, reduction);
    return std::make_tuple(
      makeBatched(std::get<0>(results), std::get<1>(results), cur_level),
      makeBatched(std::get<2>(results), std::get<3>(results), cur_level)
    );
  }

I don't like either solution so haven't implemented any of them yet.

Error while installing functorch

I am trying to install functorch for doing some tests with vmap but I am not being able to install it following the instructions in the README. I'm just trying to run the Colab, but I'm getting the following error:

Collecting git+https://github.com/zou3519/functorch.git
  Cloning https://github.com/zou3519/functorch.git to /tmp/pip-req-build-57x_64b3
  Running command git clone -q https://github.com/zou3519/functorch.git /tmp/pip-req-build-57x_64b3
Requirement already satisfied: torch>=1.9.0.dev in /usr/local/lib/python3.7/dist-packages (from functorch==0.0.1a0+2890f63) (1.9.0.dev20210429+cpu)
Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch>=1.9.0.dev->functorch==0.0.1a0+2890f63) (3.7.4.3)
Building wheels for collected packages: functorch
  Building wheel for functorch (setup.py) ... error
  ERROR: Failed building wheel for functorch
  Running setup.py clean for functorch
Failed to build functorch
Installing collected packages: functorch
    Running setup.py install for functorch ... error
ERROR: Command errored out with exit status 1: /usr/bin/python3 -u -c 'import sys, setuptools, tokenize; sys.argv[0] = '"'"'/tmp/pip-req-build-57x_64b3/setup.py'"'"'; __file__='"'"'/tmp/pip-req-build-57x_64b3/setup.py'"'"';f=getattr(tokenize, '"'"'open'"'"', open)(__file__);code=f.read().replace('"'"'\r\n'"'"', '"'"'\n'"'"');f.close();exec(compile(code, __file__, '"'"'exec'"'"'))' install --record /tmp/pip-record-lv6v2_fa/install-record.txt --single-version-externally-managed --compile --user --prefix= Check the logs for full command output.

Any idea how to solve it?

grad transform failing tests tracking issue

Current fails for:

Calls internal_new_from_data: (#65)

  • __getitem__
  • __rpow__ (straight up calls torch.tensor)
  • torch.tensor
  • Tensor.new()

Data pointer accessed by helper function (#65)

  • linalg.cholesky (linalg_cholesky calls linalg_cholesky_ex (prim) and does error checking)
  • linalg.inv (linalg_inv calls linalg_inv_ex (prim) and does error checking)
  • linalg.matrix_power (can call inv)

The norm problem (#14); AKA: CompositeImplicitAutograd op calls an "out= variant" that calls raw native::resize_ on tensors.

  • linalg.matrix_norm
  • linalg.norm
  • nanquantile
  • quantile

Requires an integer tensor for the "splits" argument...

  • tensor_split

Test by uncommenting out https://github.com/zou3519/functorch/blob/ae97def8eb8508418053a1a7c81371b9b44dcc3d/test/test_grad.py#L49. I haven't investigated the problems yet.

Miscellaneous non-OpInfo problems (test_torch.py)

  • Tensor.numpy
  • Tensor.tolist
  • copy.copy
  • to_dlpack
  • repeat_interleave
  • Tensor.map_
  • Tensor.map2_
  • pickle.dumps
  • printing
  • torch.sobol_engine_initialize_state
  • assigning to Tensor.data

Miscellaneous non-OpInfo problems (test_nn.py)

  • F.ctc_loss
  • F.max_pool1d (testing artifact)
  • Lazy modules

Miscellaneous non-OpInfo problems (test_linalg.py)

Miscellaneous non-OpInfo problems (test_tensor_creation.py)

Miscellaneous non-OpInfo problems test_unary_ufuncs.py

  • conj

https://docs.google.com/spreadsheets/d/18sv-cKBqMGVCNdclFk5jB9LmQJGzb_eNAE9O2-oep3Q/edit?usp=sharing

batching rule for repeat has some cases it fails on

def f(x):
    return x.repeat(0)

print(vmap(f)(torch.randn(3)).shape) # Returns shape [1,0]

This one is kind of awkward (I'm surprised it's legal), but I think it makes sense to preserve the invariant that the output always has shape B in the batching dim.

def f(x):
    return x.repeat(1)

print(vmap(f)(torch.randn((3))).shape) # Returns shape [1,3], should return [3,1]

Transform testing tracking issue

Add OpInfo-based testing for:

  • vmap
  • grad
  • vjp
  • vjp of vjp
  • vjp of vmap
  • vmap of vjp
  • vmap of vmap

grad is really a special case of vjp. Do we need more tests for it?

  • See if we can implement grad by calling vjp (a rough sketch follows this list).
  • There are a lot of exceptions for in-place operations. See if we can/should test in-place in OpInfo testing
  • Tests to check which ops have batching rules is useful to make sure we actually register the batching rules correctly
  • Stress testing for TensorWrapper: Wrap all Tensors in TensorWrappers and send them through the PyTorch test suite...
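
Regarding implementing grad by calling vjp, a rough sketch for a single argument and a scalar output (an assumption for discussion, not the library's implementation) could look like:

import torch
from functorch import vjp

def grad_via_vjp(f):
    # grad(f)(x) for a scalar-valued f is the vjp of f at x with cotangent 1.
    def wrapper(x):
        out, vjp_fn = vjp(f, x)
        (g,) = vjp_fn(torch.ones_like(out))
        return g
    return wrapper

x = torch.randn([])
assert torch.allclose(grad_via_vjp(torch.sin)(x), x.cos())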
