

LeanderK commented on June 11, 2024

Is this a legal way to solve this? It doesn't give me an error, but I am very unsure why this now works.

import torch
from functorch import vmap

# Assumes model_resnet, sample_data, rand_tensor and alphas are defined elsewhere, and that
# named_params_data holds (name, tensor) pairs for model_resnet's original parameters.
def test_resnet_2(new_params):
    def interpolate(alpha):
        with torch.no_grad():
            # Swap in the interpolated parameters old_p + alpha * new_p.
            for i, (name, old_p) in enumerate(named_params_data):
                new_p = new_params[i]
                param_names = name.split(".")
                current = model_resnet
                # Walk the dotted name down to the submodule that owns the parameter.
                for p in param_names[:-1]:
                    current = getattr(current, p)
                setattr(current, param_names[-1], torch.nn.Parameter(old_p + alpha*new_p))

            out = model_resnet(sample_data)

            # Restore the original parameters.
            for name, old_p in named_params_data:
                param_names = name.split(".")
                current = model_resnet
                for p in param_names[:-1]:
                    current = getattr(current, p)
                setattr(current, param_names[-1], torch.nn.Parameter(old_p))
            return out
    return interpolate

model_resnet.eval()
to_vmap_resnet = test_resnet_2(rand_tensor)
test_out2 = vmap(to_vmap_resnet)(alphas)

EDIT: I found an even simpler solution. This is the correct approach, right?

import torch
from functorch import make_functional_with_buffers, vmap

def test_resnet_4(new_params, sample_data, model_resnet):
    # Split the module into a stateless function plus its parameters and buffers.
    func_model, params, buff = make_functional_with_buffers(model_resnet, disable_autograd_tracking=True)
    def interpolate(alpha):
        with torch.no_grad():
            # Build old_p + alpha * new_p for every parameter, then call the stateless model.
            interpol_params = [torch.nn.Parameter(old_p + alpha*new_params[i]) for i, old_p in enumerate(params)]
            out = func_model(interpol_params, buff, sample_data)
            return out
    return interpolate

model_resnet.eval()
to_vmap_resnet = test_resnet_4(rand_tensor, sample_data, model_resnet)
test_out2 = vmap(to_vmap_resnet)(alphas)


samdow commented on June 11, 2024

Hi @LeanderK! Thanks for the interesting issue! Since it sounds like this works, that's a totally fine way of doing it!

One thing that might come up: if you run this model N times (instead of once), it will be faster to do something similar to the ensembling API. In your version the new parameters are built N times, whereas this way you build them only once and then combine them. This is also useful if you want to train the model (batch norm should work with the ensemble).
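A minimal sketch of the build-once pattern, assuming PyTorch 2.x, where functorch's `vmap` and the functional module API live in `torch.func`. The small `nn.Sequential` model and the names `direction`, `sample_data` and `alphas` are hypothetical stand-ins for `model_resnet`, `rand_tensor`, `sample_data` and `alphas`. The interpolated parameters for every alpha are built in one broadcasted step outside the loop, and only the vmapped forward pass is repeated.

```python
import torch
from torch import nn
from torch.func import functional_call, vmap

torch.manual_seed(0)

# Hypothetical stand-ins for model_resnet, sample_data, rand_tensor and alphas.
model = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 2)).eval()
sample_data = torch.randn(6, 3)
alphas = torch.linspace(0.0, 1.0, 5)
params = {n: p.detach() for n, p in model.named_parameters()}
direction = {n: torch.randn_like(p) for n, p in params.items()}

# Built ONCE: old_p + alpha * new_p for every alpha at the same time.
# alphas is reshaped to (5, 1, ..., 1) so it broadcasts against each
# parameter, giving every parameter a leading dimension of size 5.
stacked = {
    n: p + alphas.view(-1, *([1] * p.dim())) * direction[n]
    for n, p in params.items()
}

def forward(p):
    # p is one interpolated parameter set; vmap slices dim 0 of each entry.
    return functional_call(model, p, (sample_data,))

batched_forward = vmap(forward)

with torch.no_grad():
    # N = 3 runs reuse the same stacked parameters; nothing is rebuilt.
    outs = [batched_forward(stacked) for _ in range(3)]

print(outs[0].shape)  # torch.Size([5, 6, 2])
```

Only the cheap forward call sits inside the loop; the parameter construction cost is paid once regardless of how many runs follow.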

For this use case, since it looks like you want very specific initializations, it might be better to riff on the idea of the ensemble API:

import torch
from functorch import make_functional_with_buffers, vmap

def test_resnet_4(func_model, buff, sample_data):
    def interpolate(interpol_params):
        with torch.no_grad():
            out = func_model(interpol_params, buff, sample_data)
            return out
    return interpolate

model_resnet.eval()

func_model, params, buff = make_functional_with_buffers(model_resnet, disable_autograd_tracking=True)
# Build one full parameter list per alpha...
interpol_params = [[torch.nn.Parameter(old_p + alpha*rand_tensor[i]) for i, old_p in enumerate(params)] for alpha in alphas]
# ...then regroup by parameter and stack along a new leading alpha dimension.
interpol_params = [torch.stack(i) for i in zip(*interpol_params)] # this is basically what the ensemble API is doing
to_vmap_resnet = test_resnet_4(func_model, buff, sample_data)
# vmap slices dim 0 of every stacked parameter, i.e. one alpha per call.
test_out2 = vmap(to_vmap_resnet)(interpol_params)
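The comment in the code above notes that the zip-and-stack step is basically what the ensemble API does. A small sketch of just that transposition, using plain tensors with hypothetical shapes in place of the real ResNet parameters: a nested list indexed `[alpha][param]` becomes a list indexed `[param]`, where each entry gains a leading dimension of size `len(alphas)` that `vmap` then slices.

```python
import torch

torch.manual_seed(0)

# Hypothetical parameters: a 4x3 weight and a bias of size 4.
params = [torch.randn(4, 3), torch.randn(4)]
rand_tensor = [torch.randn_like(p) for p in params]
alphas = torch.linspace(0.0, 1.0, 5)

# Indexed [alpha][param]: one full parameter list per alpha.
per_alpha = [[old_p + alpha * rand_tensor[i] for i, old_p in enumerate(params)]
             for alpha in alphas]

# zip(*...) regroups by parameter; torch.stack adds the leading alpha dim.
per_param = [torch.stack(group) for group in zip(*per_alpha)]

print([tuple(t.shape) for t in per_param])  # [(5, 4, 3), (5, 4)]
```

After stacking, `per_param[j][k]` is parameter `j` for the `k`-th alpha, which is exactly the slice the `k`-th vmapped call receives.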

Then, if you want to train, you can also expand the buffers (one copy per alpha) and vmap across them along with interpol_params so that batch norm works.
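A sketch of that suggestion under stated assumptions: PyTorch 2.x (`torch.func`), a hypothetical `Linear` + `BatchNorm1d` model in training mode standing in for the ResNet, and interpolated parameters built by broadcasting. Each buffer is copied once per alpha with `torch.stack` (real copies rather than a stride-0 `expand`, on the assumption that the in-place running-stat updates should not write into shared memory), and `vmap` maps over parameters and buffers together, so every alpha keeps its own running statistics.

```python
import torch
from torch import nn
from torch.func import functional_call, vmap

torch.manual_seed(0)

# Hypothetical stand-in for model_resnet, in training mode so BN updates stats.
model = nn.Sequential(nn.Linear(3, 4), nn.BatchNorm1d(4)).train()
sample_data = torch.randn(8, 3)
alphas = torch.linspace(0.0, 1.0, 5)
k = len(alphas)

params = {n: p.detach() for n, p in model.named_parameters()}
direction = {n: torch.randn_like(p) for n, p in params.items()}
batched_params = {
    n: p + alphas.view(-1, *([1] * p.dim())) * direction[n]
    for n, p in params.items()
}

# One independent copy of every buffer (running_mean, running_var,
# num_batches_tracked) per alpha, stacked along a new leading dim.
batched_buffers = {
    n: torch.stack([b.clone() for _ in range(k)])
    for n, b in model.named_buffers()
}

def forward(p, b, x):
    return functional_call(model, (p, b), (x,))

# Map over params and buffers (dim 0); every alpha sees the same input batch.
out = vmap(forward, in_dims=(0, 0, None))(batched_params, batched_buffers, sample_data)
print(out.shape)  # torch.Size([5, 8, 4])
```

Because the buffers are batched, each alpha's batch-norm update writes into its own row of `batched_buffers` instead of trying to update a single shared buffer.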

Hope that helps! We are also looking at changing the module API soon to help rationalize some of the functorch API with the PyTorch API. If you're using the nightly build, I can point you to the new API if you're curious.
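Since this thread, functorch has been merged into PyTorch as `torch.func`, and `make_functional_with_buffers` is deprecated there in favor of `torch.func.functional_call`. A minimal sketch of the closure-over-alpha version in that style, assuming PyTorch 2.x; the tiny model and the names `direction`, `sample_data` and `alphas` are hypothetical stand-ins for the thread's `model_resnet`, `rand_tensor`, `sample_data` and `alphas`.

```python
import torch
from torch import nn
from torch.func import functional_call, vmap

torch.manual_seed(0)

# Hypothetical stand-ins for model_resnet, sample_data, rand_tensor and alphas.
model = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 2)).eval()
sample_data = torch.randn(6, 3)
alphas = torch.linspace(0.0, 1.0, 5)

params = {n: p.detach() for n, p in model.named_parameters()}
buffers = {n: b for n, b in model.named_buffers()}  # empty for this model
direction = {n: torch.randn_like(p) for n, p in params.items()}

def interpolate(alpha):
    # Under vmap, alpha is a 0-dim slice of alphas; build this alpha's params.
    interp = {n: p + alpha * direction[n] for n, p in params.items()}
    return functional_call(model, (interp, buffers), (sample_data,))

with torch.no_grad():
    out = vmap(interpolate)(alphas)
print(out.shape)  # torch.Size([5, 6, 2])
```

The module itself is never mutated here: `functional_call` temporarily substitutes the given tensors for the forward pass, so no restore step is needed.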

