
Comments (5)

zou3519 avatar zou3519 commented on June 3, 2024

In 2.0 you can access the APIs from both torch.func and functorch. This will be true for the foreseeable future (i.e., we will preserve BC for a few releases of PyTorch). However, there will be differences between the torch.func.* and functorch.* APIs.

In general, we're deprecating the functorch.* APIs in favor of the torch.func.* APIs. As a part of this deprecation, we're moving away from functorch.make_functional and consolidating on PyTorch's NN stateless API. More details over at pytorch/pytorch#91811.


XuehaiPan avatar XuehaiPan commented on June 3, 2024

In general, we're deprecating the functorch.* APIs in favor of the torch.func.* APIs.

@zou3519 Thanks for the comment.

As a part of this deprecation, we're moving away from functorch.make_functional and consolidating on PyTorch's NN stateless API. More details over at pytorch/pytorch#91811

One more question about the memory usage of PyTorch's NN stateless API.

We will prefer torch.func.functional_call over functorch.make_functional in the future:

import functorch
import torch

model = ...  # build NN module

# functional_call
params_and_buffers_dict = ...  # extract parameters or user-defined tensor dicts
output = torch.func.functional_call(model, params_and_buffers_dict, args=args, kwargs=kwargs)

# make_functional
fmodel, params, buffers = functorch.make_functional_with_buffers(model)
output = fmodel(params, buffers, *args, **kwargs)

In functorch.make_functional{,_with_buffers}, the copied stateless module moves its tensors to the meta device, which holds no data storage. This makes fmodel use significantly less memory than the original module. Now, the nn.utils.stateless.functional_call API requires the user to pass the full model plus a new copy of the parameters, which doubles the memory usage. This problem may be exacerbated when multi-process communication is required.

For example, stateless functional call over RPC:

import functorch
import torch
import torch.distributed.rpc as rpc

model = ...  # build NN module

# functional_call
params_and_buffers_dict = ...  # extract parameters or user-defined tensor dicts
output = rpc.rpc_sync(
    'worker1',
    torch.func.functional_call,
    args=(
        model,  # the full module, including its original parameters, is serialized and sent
        params_and_buffers_dict,
    ),
    kwargs=dict(args=args, kwargs=kwargs),
)

# make_functional
fmodel, params, buffers = functorch.make_functional_with_buffers(model)
output = rpc.rpc_sync(
    'worker1',
    fmodel,  # small serialization and communication overhead
    args=(params, buffers, *args),
    kwargs=kwargs,
)

To reduce the communication cost, users need to explicitly convert the module's tensors to the meta device before making the stateless functional call on a remote worker:

import torch
import torch.distributed.rpc as rpc

model = ...  # build NN module

# functional_call
params_and_buffers_dict = ...  # extract parameters or user-defined tensor dicts
output = rpc.rpc_sync(
    'worker1',
    torch.func.functional_call,
    args=(
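        # convert to meta device; nn.Module.to() works in place, so the local
        # `model` is converted as well (copy.deepcopy(model).to('meta') avoids that)
        model.to('meta'),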
        params_and_buffers_dict,
    ),
    kwargs=dict(args=args, kwargs=kwargs),
)
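To get a rough sense of the payload difference without setting up RPC, one can compare pickled sizes locally. This is an approximation under the assumption that pickling reflects what has to be shipped; torch.distributed.rpc uses its own pickle-based serializer and handles tensor data through its own mechanism, so exact byte counts will differ. The layer size is arbitrary, and `copy.deepcopy` keeps the local module intact because `nn.Module.to` converts a module in place.

```python
import copy
import pickle

import torch.nn as nn

# Illustrative size: 1000 x 1000 float32 weights is roughly 4 MB of data.
model = nn.Linear(1000, 1000)

# Pickling the real module has to include every parameter value.
full_bytes = len(pickle.dumps(model))

# A meta-device copy only records shapes, dtypes and strides, no data.
meta_bytes = len(pickle.dumps(copy.deepcopy(model).to("meta")))

print(f"full module: {full_bytes} bytes")
print(f"meta module: {meta_bytes} bytes")
```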


zou3519 avatar zou3519 commented on June 3, 2024

In functorch.make_functional{,_with_buffers}, the copied stateless module converts tensors with meta device, which does not hold data storage. This makes the fmodel use significantly less memory than the original module. Now, the nn.utils.stateless.functional_call API requires the user to pass a full model and a new copy of parameters.

Yes: to avoid using twice the amount of memory, users need to explicitly convert the tensors to the meta device.
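A minimal sketch of that conversion using only torch.func: deep-copy the module, move the copy to the meta device (shapes and dtypes survive, data does not), and keep the real tensors in a separate dict that is passed to functional_call. The nn.Linear model and random input are assumptions standing in for the user's model and data.

```python
import copy

import torch
import torch.nn as nn
from torch.func import functional_call

torch.manual_seed(0)
model = nn.Linear(4, 2)  # stand-in for the user's model (assumption)
x = torch.randn(3, 4)

# Real parameters and buffers, held outside the stateless skeleton.
params_and_buffers = {
    **dict(model.named_parameters()),
    **dict(model.named_buffers()),
}

# Stateless skeleton: same module structure, but every tensor lives on the
# meta device, so it carries shape/dtype metadata and no data storage.
skeleton = copy.deepcopy(model).to("meta")

# functional_call runs the skeleton with the real tensors swapped in.
output = functional_call(skeleton, params_and_buffers, (x,))

print(all(p.is_meta for p in skeleton.parameters()))  # True
print(torch.allclose(output, model(x)))  # True
```

Note that copy.deepcopy briefly allocates a second full copy of the weights before .to("meta") releases it; the deep copy is there because nn.Module.to converts a module in place.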


zou3519 avatar zou3519 commented on June 3, 2024

Concretely, it depends on where the user wants to store their parameters:

  • If they want to leave them in the module and update them in-place, then there is no need to construct a meta version of the module. This is a good workflow if the user does something like computing per-sample gradients and then using those to optimize the original parameters.
  • If they want to update the parameters out-of-place, then the user should explicitly convert the parameters in the module to the meta device, since those copies are not going to be used.
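A minimal sketch of both workflows with torch.func.grad and torch.func.vmap. The nn.Linear model, random data, MSE loss, and learning rate are illustrative assumptions; what matters is where the parameters live in each case.

```python
import copy

import torch
import torch.nn as nn
from torch.func import functional_call, grad, vmap

torch.manual_seed(0)
model = nn.Linear(3, 1)  # stand-in model (assumption)
x, y = torch.randn(8, 3), torch.randn(8, 1)
lr = 0.1  # illustrative learning rate


def make_loss(module):
    # Per-sample MSE loss: run `module` with `params` swapped in for its own.
    def loss_fn(params, sample, target):
        pred = functional_call(module, params, (sample.unsqueeze(0),))
        return ((pred - target.unsqueeze(0)) ** 2).mean()

    return loss_fn


# Workflow 1: parameters stay in the module and are updated in place.
params = {k: v.detach() for k, v in model.named_parameters()}
per_sample = vmap(grad(make_loss(model)), in_dims=(None, 0, 0))(params, x, y)
before = model.weight.detach().clone()
with torch.no_grad():
    for name, p in model.named_parameters():
        # Average the per-sample gradients and step the module's own tensors.
        p -= lr * per_sample[name].mean(dim=0)

# Workflow 2: parameters live in a dict and are updated out of place,
# so the module only needs to be a meta-device skeleton.
skeleton = copy.deepcopy(model).to("meta")
new_params = {k: v.detach().clone() for k, v in model.named_parameters()}
per_sample = vmap(grad(make_loss(skeleton)), in_dims=(None, 0, 0))(new_params, x, y)
new_params = {k: v - lr * per_sample[k].mean(dim=0) for k, v in new_params.items()}
```

In workflow 1 the module owns the only copy of the weights, so there is nothing to gain from a meta version. In workflow 2 the weights exist only in new_params, which is why the module's own tensors can be reduced to meta placeholders.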

Does that alleviate the concern? If so, I'll update the migration guide to reflect this -- thank you for your feedback.


XuehaiPan avatar XuehaiPan commented on June 3, 2024

Thanks for the comments. Now I have no more questions about migration.
