Comments (5)
In 2.0 you can access the APIs from both torch.func and functorch. This will be true for the foreseeable future (i.e., we will preserve BC for a few releases of PyTorch). However, there will be differences between the torch.func.* and functorch.* APIs.
In general, we're deprecating the functorch.* APIs in favor of the torch.func.* APIs. As a part of this deprecation, we're moving away from functorch.make_functional and consolidating on PyTorch's NN stateless API. More details over at pytorch/pytorch#91811
@zou3519 Thanks for the comment.
One more question about the memory usage of PyTorch's NN stateless API.
We will prefer torch.func.functional_call over functorch.make_functional in the future:
import functorch
import torch
model = ... # build NN module
# functional_call
params_and_buffers_dict = ... # extract parameters or user-defined tensor dicts
output = torch.func.functional_call(model, params_and_buffers_dict, args=args, kwargs=kwargs)
# make_functional
fmodel, params, buffers = functorch.make_functional_with_buffers(model)
output = fmodel(params, buffers, *args, **kwargs)
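For reference, here is one possible way to build the parameters-and-buffers dict that functional_call expects (the concrete module and the detach() calls are illustrative assumptions, not part of the original comment):
import torch

model = torch.nn.Linear(4, 2)  # illustrative module

# functional_call accepts any dict mapping fully-qualified names to
# tensors, so entries can also be user-defined tensors
params_and_buffers_dict = {
    **{name: p.detach() for name, p in model.named_parameters()},
    **{name: b.detach() for name, b in model.named_buffers()},
}

output = torch.func.functional_call(model, params_and_buffers_dict, args=(torch.randn(3, 4),))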
In functorch.make_functional{,_with_buffers}, the copied stateless module converts its tensors to the meta device, which does not hold data storage. This makes the fmodel use significantly less memory than the original module. In contrast, the nn.utils.stateless.functional_call API requires the user to pass the full model plus a new copy of the parameters. That is twice the memory usage. This memory problem may be exacerbated when multi-process communication is required.
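As a small illustration of the meta device behavior referenced here (the module is an illustrative assumption): converting to meta keeps only shape and dtype metadata.
import torch

layer = torch.nn.Linear(4, 4)
layer.to('meta')             # nn.Module.to converts the parameters in place
print(layer.weight.is_meta)  # True: only metadata, no data storage allocated
print(layer.weight.shape)    # torch.Size([4, 4])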
For example, stateless functional call over RPC:
import functorch
import torch
import torch.distributed.rpc as rpc

model = ...  # build NN module

# functional_call
params_and_buffers_dict = ...  # extract parameters or user-defined tensor dicts
output = rpc.rpc_sync(
    'worker1',
    torch.func.functional_call,
    args=(
        model,  # the original parameters also need communication
        params_and_buffers_dict,
    ),
    kwargs=dict(args=args, kwargs=kwargs),
)

# make_functional
fmodel, params, buffers = functorch.make_functional_with_buffers(model)
output = rpc.rpc_sync(
    'worker1',
    fmodel,  # small serialization and communication overhead
    args=(params, buffers, *args),
    kwargs=kwargs,
)
To reduce the communication cost, users need to explicitly convert the tensors to the meta device before the stateless functional call on remote workers:
import torch
import torch.distributed.rpc as rpc

model = ...  # build NN module

# functional_call
params_and_buffers_dict = ...  # extract parameters or user-defined tensor dicts
output = rpc.rpc_sync(
    'worker1',
    torch.func.functional_call,
    args=(
        model.to('meta'),  # convert to meta device
        params_and_buffers_dict,
    ),
    kwargs=dict(args=args, kwargs=kwargs),
)
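Note that nn.Module.to() modifies the module in place, so model.to('meta') also replaces the caller's local parameters with meta tensors. That is acceptable here because the real tensors travel separately in params_and_buffers_dict, but the original parameter storages are no longer reachable through model afterwards.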
Yes, to avoid using twice the amount of memory, users need to explicitly convert the tensors to the meta device.
Concretely, it depends on where the user wants to store their parameters:
- If they want to leave them in the module and update them in-place, then there is no need to construct a meta version of the module. This is a good workflow if the user does something like compute per-sample gradients and then uses those to optimize the original parameters (see the sketch after this list).
- If they want to update the parameters out-of-place, then the user should explicitly convert the parameters in the module to the meta device, since those are not going to be used.
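A minimal sketch of the first workflow (the concrete model, loss, and learning rate are illustrative assumptions, not part of the thread): per-sample gradients are computed with torch.func, then applied in-place to the original module's parameters, with no meta copy of the module involved.
import torch

model = torch.nn.Linear(4, 1)  # illustrative model
params = {name: p.detach() for name, p in model.named_parameters()}

def loss_fn(params, x, y):
    # treat a single sample as a batch of one for the module call
    out = torch.func.functional_call(model, params, (x.unsqueeze(0),))
    return torch.nn.functional.mse_loss(out, y.unsqueeze(0))

x, y = torch.randn(8, 4), torch.randn(8, 1)
# vmap over the batch dimension of (x, y) yields per-sample gradients
per_sample_grads = torch.func.vmap(
    torch.func.grad(loss_fn), in_dims=(None, 0, 0)
)(params, x, y)

# update the original parameters in place; no meta version of the module is needed
with torch.no_grad():
    for name, p in model.named_parameters():
        p -= 0.1 * per_sample_grads[name].mean(dim=0)  # plain SGD step, assumed learning rate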
Does that alleviate the concern? If so, I'll update the migration guide to reflect this -- thank you for your feedback.
Thanks for the comments. Now I have no more questions about migration.