`functorch.grad` computes gradients with respect to the first argument you pass it. In your code that argument is currently `params` (all parameters in the model); the solution is to pass it only the parameters you want gradients of.
Some example code (`net`, `criterion` and `train_poi_set` come from your own script):
```python
import torch
from functorch import make_functional_with_buffers, vmap, grad

# `net`, `criterion` and `train_poi_set` are defined elsewhere in your script
fmodel, params, buffers = make_functional_with_buffers(net, disable_autograd_tracking=True)

def compute_loss_stateless_model(last_layers_params, first_layers_params, buffers, sample, target):
    batch = sample.unsqueeze(0)
    targets = target.unsqueeze(0)
    # put the params back together into a single tuple, in the order fmodel expects
    params = (*first_layers_params, *last_layers_params)
    predictions = fmodel(params, buffers, batch)
    loss = criterion(predictions, targets)
    return loss

# grad differentiates w.r.t. the first argument only (argnums=0 by default)
ft_compute_grad = grad(compute_loss_stateless_model)

# split the params we want gradients of from the params we don't;
# this example split assumes the last layer's weight and bias are the final two entries of `params`
first_layers_params, last_layers_params = params[:-2], params[-2:]
gradient = ft_compute_grad(last_layers_params, first_layers_params, buffers,
                           train_poi_set[0][0].cuda(), torch.tensor(train_poi_set[0][1]).cuda())
```
from functorch.
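Below is a minimal, self-contained sketch of the same idea. It swaps in `torch.func` (the PyTorch 2.x successor to functorch, where `functional_call` replaces the deprecated `make_functional_with_buffers`), and it selects parameters by name instead of by position. The toy `nn.Sequential` model, the `"2."` name prefix used to pick out the last layer, and the random data are assumptions for illustration only. It also shows how `vmap` turns the single-sample gradient into per-sample gradients.

```python
import torch
from torch import nn
from torch.func import functional_call, grad, vmap

torch.manual_seed(0)

# Hypothetical stand-in for your `net`; module "2" is the last Linear layer
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))
criterion = nn.CrossEntropyLoss()

# Split parameters by name: "last layers" = everything under the "2." prefix (assumption)
all_params = {k: v.detach() for k, v in net.named_parameters()}
buffers = {k: v.detach() for k, v in net.named_buffers()}
last_params = {k: v for k, v in all_params.items() if k.startswith("2.")}
first_params = {k: v for k, v in all_params.items() if not k.startswith("2.")}

def compute_loss(last_params, first_params, buffers, sample, target):
    # grad() only differentiates w.r.t. argument 0, i.e. `last_params`
    batch = sample.unsqueeze(0)
    targets = target.unsqueeze(0)
    params_and_buffers = {**first_params, **last_params, **buffers}
    predictions = functional_call(net, params_and_buffers, (batch,))
    return criterion(predictions, targets)

ft_compute_grad = grad(compute_loss)

samples = torch.randn(5, 4)
targets = torch.randint(0, 3, (5,))

# Gradient for one sample: a dict with the same keys as `last_params`
single = ft_compute_grad(last_params, first_params, buffers, samples[0], targets[0])

# Per-sample gradients: batch only over the sample and target arguments
per_sample = vmap(ft_compute_grad, in_dims=(None, None, None, 0, 0))(
    last_params, first_params, buffers, samples, targets
)

print(sorted(single.keys()))         # ['2.bias', '2.weight']
print(per_sample["2.weight"].shape)  # torch.Size([5, 3, 8])
```

Selecting by name avoids depending on the order of the parameter tuple, which is easy to get wrong when the model changes.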
@zou3519 I have a similar question, but about `jacrev`. For example, I only want to compute the Jacobian with respect to the last layers. Can this work?
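The same trick should carry over to `jacrev`: like `grad`, it differentiates only with respect to `argnums=0` by default, so putting the last-layer parameters in the first argument should yield Jacobians for just those parameters. A minimal sketch, again using `torch.func` rather than functorch, with a hypothetical toy model, an assumed `"2."` name prefix for the last layer, and random input:

```python
import torch
from torch import nn
from torch.func import functional_call, jacrev

torch.manual_seed(0)

# Hypothetical model; module "2" is the last Linear layer
net = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 3))
all_params = {k: v.detach() for k, v in net.named_parameters()}
last_params = {k: v for k, v in all_params.items() if k.startswith("2.")}
first_params = {k: v for k, v in all_params.items() if not k.startswith("2.")}

def model_output(last_params, first_params, x):
    # jacrev, like grad, differentiates w.r.t. argument 0 by default
    return functional_call(net, {**first_params, **last_params}, (x,))

x = torch.randn(2, 4)  # batch of 2 inputs -> output shape (2, 3)
jac = jacrev(model_output)(last_params, first_params, x)

# One Jacobian per last-layer parameter, shaped output.shape + param.shape
print(jac["2.weight"].shape)  # torch.Size([2, 3, 3, 8])
print(jac["2.bias"].shape)    # torch.Size([2, 3, 3])
```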
Related Issues
- Add pytorch 1.13.1 compatibility
- Unit Test Error When Testing vmap With Missing Module "autograd_function_db"
- Will pmap be supported in functorh?
- [Question] Packaging policy for `functorch` and `torch.func`
- INTERNAL_ASSERT failed
- RuntimeError: Batching rule not implemented for aten::is_same_size. We could not generate a fallback.
- Vmap and backward hook problem
- item() support for vmap
- Performance drop because of not yet implemented batching rule for bincount
- Use functional models inside usual nn.Module
- Error about using a grad transform with in-place operation is inconsistent with and without DDP
- How to get the jacobian matrix in GCNs?
- Per-sample-gradient: Get gradient 0 when using grad(params_tograd, params) with respect to part of model's parameters
- Can I call torch.utils.data.WeightedRandomSampler inside vmap?
- vmap fails if your model includes full_backward_hook in pytorch2.0
- wrapper->level().value() <= current_level INTERNAL ASSERT FAILED at "../aten/src/ATen/functorch/ADInterpreters.cpp":39
- Swapping 2 columns in a 2d tensor
- vmap does not support Tensor.clone()
- Small difference between functorch grads and torch.autograd.grad