Comments (8)

DavidBert commented on August 18, 2024

Thanks @eracah! Setting requires_grad=False is indeed the easiest fix so far for my specific problem.
I tested it and it works as expected.
Thanks for the help and all your work; Composer is a very cool library!

mvpatel2000 commented on August 18, 2024

Hm... yes, I think it might not be properly saving the subset of parameters if the group does not contain all params.

Code to save:

if num_param_groups > 1:
    if not fsdp_config.use_orig_params:
        raise RuntimeError(
            'Multiple optimizer groups with FSDP are only supported with '
            'use_orig_params=True.',
        )
    # optimizer.param_groups do not contain parameter names, which are needed
    # to keep track of the different parameters in each group,
    # so we use the pointers between model.parameters() and model.named_parameters()
    # to get the names of the parameters within optimizer.param_groups
    param_pointer_to_param_name = {id(p): n for n, p in model.named_parameters()}
    param_name_to_group_num = {}
    group_num_to_param_group_info = {}
    for group_num in range(len(optim.param_groups)):
        # Need to in-line to avoid a reference which causes FSDP to allocate extra GPU memory
        # group = optim.param_groups[group_num]
        for param_num in range(len(optim.param_groups[group_num]['params'])):
            # Need to in-line to avoid a reference which causes FSDP to allocate extra GPU memory
            # param = optim.param_groups[group_num]['params'][param_num]
            param_name_to_group_num[param_pointer_to_param_name[id(
                optim.param_groups[group_num]['params'][param_num],
            )]] = group_num
        # this includes optimizer-specific values like lr, eps
        # this will be used as the kwargs for the optim param groups later
        optimizer_specific_group_info = {
            k: v for k, v in optim.param_groups[group_num].items() if k != 'params'
        }
        group_num_to_param_group_info[group_num] = optimizer_specific_group_info
else:
    optimizer_specific_info = {k: v for k, v in optim.param_groups[0].items() if k != 'params'}
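
For context, here is a minimal standalone sketch (a toy model, not Composer internals) of the id()-based name recovery this code relies on:

import torch
import torch.nn as nn

# Toy two-layer model with two optimizer param groups.
model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
optim = torch.optim.Adam([
    {'params': model[0].parameters(), 'lr': 1e-3},
    {'params': model[1].parameters(), 'lr': 1e-4},
])

# Param groups hold bare tensors with no names, so names are recovered by
# matching object identity (id) against model.named_parameters().
param_id_to_name = {id(p): n for n, p in model.named_parameters()}
for group_num, group in enumerate(optim.param_groups):
    names = [param_id_to_name[id(p)] for p in group['params']]
    hparams = {k: v for k, v in group.items() if k != 'params'}
    print(group_num, names, hparams)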

@sashaDoubov what do you think? IIRC you added this bit.

@DavidBert do you have a repro we can use here?

DavidBert commented on August 18, 2024

Thanks for the quick answer!
This code should demonstrate the undesired behavior:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# optimizer
from torch.optim import Adam

from composer import Trainer
from composer.models import ComposerClassifier
from composer.utils import dist
import copy


class Model(nn.Module):
    """Toy convolutional neural network architecture in pytorch for MNIST."""

    def __init__(self, num_classes: int = 10):
        super().__init__()

        self.num_classes = num_classes

        self.conv1 = nn.Conv2d(1, 16, (3, 3), padding=0)
        self.conv2 = nn.Conv2d(16, 32, (3, 3), padding=0)
        self.bn = nn.BatchNorm2d(32)
        self.fc1 = nn.Linear(32 * 16, 32)
        self.fc2 = nn.Linear(32, num_classes)

    def forward(self, x):
        out = self.conv1(x)
        out = F.relu(out)
        out = self.conv2(out)
        out = self.bn(out)
        out = F.relu(out)
        out = F.adaptive_avg_pool2d(out, (4, 4))
        out = torch.flatten(out, 1, -1)
        out = self.fc1(out)
        out = F.relu(out)
        return self.fc2(out)

transform = transforms.Compose([transforms.ToTensor()])
dataset = datasets.MNIST("data", train=True, download=True, transform=transform)
sampler = dist.get_sampler(dataset, shuffle=True)

dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)

model = Model()
optimizer = Adam(model.conv1.parameters(), lr=1e-3)
unwrapped_optimizer = copy.deepcopy(optimizer)

trainer = Trainer(
    model=ComposerClassifier(module=model, num_classes=10),
    train_dataloader=dataloader,
    max_duration="2ep",
    optimizers=optimizer,
    fsdp_config={
        "sharding_strategy": "_HYBRID_SHARD_ZERO2",
        "mixed_precision": "PURE",
        "backward_prefetch": "BACKWARD_PRE",
    },
)


with trainer.state.model.module.summon_full_params(trainer.state.model.module):
    nb_parameters_before_fsdp = len(unwrapped_optimizer.param_groups[0]["params"])
    nb_parameters_after_fsdp = len(trainer.state.optimizers[0].param_groups[0]["params"])
    assert nb_parameters_before_fsdp == nb_parameters_after_fsdp, f"expected {nb_parameters_before_fsdp} but got {nb_parameters_after_fsdp}"

mvpatel2000 commented on August 18, 2024

CC: @eracah can you take a look?

eracah commented on August 18, 2024

Good find, @DavidBert! It looks like here we re-init the optimizer with all parameters, even if you created it with only a subset of params (for optimizers with a single param_group). We'll file a bug to re-create the optimizer using just the parameters from the original one.

For now, to unblock yourself, you could make one param_group with the Model().conv1.parameters() and another param_group with the rest of the params using the Optimizer.add_param_group function. That creates more than one param group, which guarantees the optimizer is recreated with the correct setup. If all you want to do is freeze the other parameters, you can instead use the optimizer as normal but set requires_grad=False on those parameters; that might be simpler than creating an optimizer with a subset of the parameters. Both options are sketched below.
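
Something along these lines (a rough sketch; the conv1 split and the lr values are just placeholders):

from torch.optim import Adam

model = Model()  # the toy MNIST model from the repro above

# Option 1: force more than one param group via add_param_group, so the
# original grouping is preserved when the optimizer is re-created under FSDP.
optimizer = Adam(model.conv1.parameters(), lr=1e-3)
other_params = [p for n, p in model.named_parameters() if not n.startswith('conv1')]
optimizer.add_param_group({'params': other_params, 'lr': 0.0})  # lr=0 keeps them effectively frozen

# Option 2: one optimizer over all params, freezing everything you don't want trained.
# optimizer = Adam(model.parameters(), lr=1e-3)
# for name, p in model.named_parameters():
#     if not name.startswith('conv1'):
#         p.requires_grad = False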

Lmk if that helps unblock you!
cc: @sashaDoubov @mvpatel2000

eracah commented on August 18, 2024

No problem! Glad we could unblock you, @DavidBert !

mvpatel2000 commented on August 18, 2024

This should now be fixed in general.

DavidBert commented on August 18, 2024

Thank you guys!
