Comments (2)
Is this a legal way to solve this? It doesn't give me an error, but I am very unsure why this now works.
def test_resnet_2(new_params):
    def interpolate(alpha):
        with torch.no_grad():
            # named_params_data: precomputed list of (name, original tensor) pairs
            # (see the setup sketch below)
            for i, (name, old_p) in enumerate(named_params_data):
                new_p = new_params[i]
                # walk the module tree by dotted name and swap in the interpolated parameter
                param_names = name.split(".")
                current = model_resnet
                for p in param_names[:-1]:
                    current = getattr(current, p)
                setattr(current, param_names[-1], torch.nn.Parameter(old_p + alpha * new_p))
            out = model_resnet(sample_data)
            # restore the original parameters so every call starts from the same point
            for i, (name, old_p) in enumerate(named_params_data):
                param_names = name.split(".")
                current = model_resnet
                for p in param_names[:-1]:
                    current = getattr(current, p)
                setattr(current, param_names[-1], torch.nn.Parameter(old_p))
        return out
    return interpolate

model_resnet.eval()
to_vmap_resnet = test_resnet_2(rand_tensor)
test_out2 = vmap(to_vmap_resnet)(alphas)
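For readers following along, the snippet leaves a few names undefined. A minimal setup that makes it self-contained might look like this; all of the definitions below are hypothetical reconstructions (assuming a standard torchvision ResNet), not part of the original post:

    import torch
    from torchvision.models import resnet18
    from functorch import vmap, make_functional_with_buffers

    model_resnet = resnet18()
    sample_data = torch.randn(1, 3, 224, 224)    # fixed input batch, reused for every alpha
    alphas = torch.linspace(0.0, 1.0, steps=5)   # interpolation coefficients to vmap over

    # snapshot of the original parameters, plus one random direction per parameter
    named_params_data = [(name, p.detach().clone()) for name, p in model_resnet.named_parameters()]
    rand_tensor = [torch.randn_like(p) for _, p in named_params_data]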
EDIT: found an even simpler solution. This is the correct approach, right?
def test_resnet_4(new_params, sample_data, model_resnet):
    # make_functional_with_buffers splits the module into a pure function,
    # its parameters, and its buffers
    func_model, params, buff = make_functional_with_buffers(model_resnet, disable_autograd_tracking=True)

    def interpolate(alpha):
        with torch.no_grad():
            interpol_params = [torch.nn.Parameter(old_p + alpha * new_params[i]) for i, old_p in enumerate(params)]
            out = func_model(interpol_params, buff, sample_data)
        return out
    return interpolate

model_resnet.eval()
to_vmap_resnet = test_resnet_4(rand_tensor, sample_data, model_resnet)
test_out2 = vmap(to_vmap_resnet)(alphas)
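One quick way to sanity-check the vmapped version (a hypothetical check, not from the original post): the first slice of the batched output should match a plain, un-vmapped call at the first alpha:

    # out at alphas[0] computed directly should equal row 0 of the batched result
    single = to_vmap_resnet(alphas[0])
    assert torch.allclose(test_out2[0], single, atol=1e-5)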
Hi @LeanderK! Thanks for the interesting issue! Since it sounds like this works, that's a totally fine way of doing it!
One thing that might come up: if you do N runs of this model (instead of 1), it will be faster to do something similar to the ensembling API, since in your version you would be building the new parameters N times, whereas this way you only build them once and then combine them. This is also useful if you want to train the model (batch norm should work with the ensemble).
For this use case, since it looks like you want to have very specific initializations, it might be better to riff on the idea of the ensemble API:
def test_resnet_4(func_model, buff, sample_data):
    def interpolate(interpol_params):
        with torch.no_grad():
            out = func_model(interpol_params, buff, sample_data)
        return out
    return interpolate

model_resnet.eval()
func_model, params, buff = make_functional_with_buffers(model_resnet, disable_autograd_tracking=True)
# build one parameter list per alpha, then stack per-parameter so each tensor
# gains a leading dim of len(alphas) that vmap can map over
interpol_params = [[torch.nn.Parameter(old_p + alpha * rand_tensor[i]) for i, old_p in enumerate(params)] for alpha in alphas]
interpol_params = [torch.stack(i) for i in zip(*interpol_params)]  # this is basically what the ensemble API is doing
to_vmap_resnet = test_resnet_4(func_model, buff, sample_data)
test_out2 = vmap(to_vmap_resnet)(interpol_params)
Then, if you want to train, you can also expand the buffers and vmap across them along with interpol_params so that batch norm works; a sketch of that follows below.
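A minimal sketch of that buffer expansion, reusing the names above. The .clone() is an assumption about the desired behavior: it gives each interpolation point its own private running statistics.

    # Stack a private copy of every buffer (e.g. batch-norm running stats) along
    # a new leading dim of size len(alphas), so vmap can map over buffers too.
    # (For actual training you would also drop disable_autograd_tracking=True above.)
    N = len(alphas)
    stacked_buff = [b.unsqueeze(0).expand(N, *b.shape).clone() for b in buff]

    def forward(p, b):
        return func_model(p, b, sample_data)

    out = vmap(forward)(interpol_params, stacked_buff)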
Hope that helps! We are also looking at changing the module API to help rationalize some of the functorch API with the PyTorch API soon. If you're using the nightly build, I can point you to the new API if you're curious.
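For reference, here is a sketch of how the same interpolation could look under the newer torch.func API the comment alludes to (this translation is an assumption, not something from the thread): functional_call takes a name-to-tensor dict instead of the positional parameter list from make_functional_with_buffers.

    import torch
    from torch.func import functional_call, vmap

    # base parameters and a random direction, keyed by parameter name
    base = {name: p.detach() for name, p in model_resnet.named_parameters()}
    direction = {name: torch.randn_like(p) for name, p in base.items()}

    def interpolate(alpha):
        # buffers not present in the dict are taken from the module itself
        state = {name: p + alpha * direction[name] for name, p in base.items()}
        return functional_call(model_resnet, state, (sample_data,))

    out = vmap(interpolate)(alphas)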