
test-completion-transformer's Issues

Understanding InverseSquareRootLR

@jdepoix, I have been looking around at learning rate schedulers, and I saw your very nice implementation here. I believe I'm either misunderstanding something or there might be a small bug in your scheduler. Thanks for any thought you put toward this!

Warmup Phase

I have been using your scheduler code with 10,000 warmup_steps and a starting lr of 0.0001. Walking through the warmup arithmetic (reproduced in the snippet after this list):

  • According to your scheduler, self._lr_steps holds the per-step increment of 0.0001 / 10,000 = 1.0e-8.
  • In get_lr, self.last_epoch is multiplied by each entry of self._lr_steps, so when self.last_epoch is 0, the model lr is 0.
  • When self.last_epoch is 10, the model lr is 1.0e-7.
  • At step 100, the lr will be 1.0e-6.
  • At step 1,000, it will be 1.0e-5.
  • At step 10,000, we recover the initial lr of 1.0e-4.
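
To sanity-check these numbers, here is a tiny standalone sketch of the warmup arithmetic (plain Python, no PyTorch needed; the 10,000 steps and 0.0001 starting lr are the values assumed above):

warmup_steps = 10_000
initial_lr = 1e-4

# Per-step increment, mirroring self._lr_steps.
lr_step = initial_lr / warmup_steps  # 1.0e-8

for step in (0, 10, 100, 1_000, 10_000):
    # During warmup, get_lr returns last_epoch * lr_step.
    print(step, step * lr_step)
# expected: 0.0, 1e-07, 1e-06, 1e-05, 0.0001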

Decay Phase

After exhausting the 10,000 warmup_steps in the example (starting lr 0.0001), the decay branch takes over, with self._decay_factors set to 0.0001 * sqrt(10,000) = 0.01. As I read it, this means the following (reproduced in the snippet after this list):

  • At step 10,000, the scheduler will return 0.01 * sqrt(10,000) = 1.
  • At step 11,000, the scheduler will return 0.01 * sqrt(11,000) = 1.04880884817.
  • At step 20,000, the scheduler will return 0.01 * sqrt(20,000) = 1.41421356237.
  • At step 50,000, the scheduler will return 0.01 * sqrt(50,000) = 2.2360679775.
  • At step 100,000, the scheduler will return 0.01 * sqrt(100,000) = 3.16227766017.
  • At step 1,000,000, the scheduler will return 0.01 * sqrt(1,000,000) = 10.
  • At step 10,000,000, the scheduler will return 0.01 * sqrt(10,000,000) = 31.6227766017.
  • At step 100,000,000, the scheduler will return 0.01 * sqrt(100,000,000) = 100.
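
To be explicit about where these numbers come from, I'm reading the decay branch as multiplying the 0.01 decay factor by sqrt(step); a minimal sketch of that reading:

import math

decay_factor = 1e-4 * math.sqrt(10_000)  # 0.01, mirroring self._decay_factors

for step in (10_000, 11_000, 20_000, 100_000, 1_000_000):
    # My reading: the returned lr grows as decay_factor * sqrt(step).
    print(step, decay_factor * math.sqrt(step))
# expected: 1.0, 1.048..., 1.414..., 3.162..., 10.0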

Issue

The warmup phase makes perfect sense to me. However, in the decay phase, shouldn't we be returning the initial learning rate divided by the factors reported above? For example, at step 11,000 I would expect the scheduler to return the initial lr of 0.0001 / 1.04880884817 = 0.00009534625, and at step 20,000, 0.0001 / 1.41421356237 = 0.00007071067. This way the learning rate shrinks/decays according to the inverse sqrt function.
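
Concretely, the post-warmup schedule I would expect looks like this (a sketch; initial_lr and warmup_steps as in the example above):

import math

initial_lr = 1e-4
warmup_steps = 10_000

for step in (11_000, 20_000, 100_000, 1_000_000):
    factor = math.sqrt(step / warmup_steps)  # 1.04880884817 at step 11,000
    print(step, initial_lr / factor)
# expected: 9.534...e-05, 7.071...e-05, 3.162...e-05, 1e-05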

Proposed Change

Here's your current code. Not a lot changes with what I'm proposing.

import warnings

from torch.optim.lr_scheduler import _LRScheduler

class InverseSquareRootLR(_LRScheduler):
    def __init__(self, optimizer, warmup_steps, last_epoch=-1):
        if warmup_steps <= 0:
            raise ValueError('warmup_steps must be > 0')
        self._warmup_steps = warmup_steps
        self._lr_steps = [param_group['lr'] / warmup_steps for param_group in optimizer.param_groups]
        self._decay_factors = [
            param_group['lr'] * warmup_steps ** 0.5 for param_group in optimizer.param_groups
        ]

        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        if self.last_epoch < self._warmup_steps:
            # Warmup: scale linearly from 0 up to the initial lr.
            return [self.last_epoch * lr_step for lr_step in self._lr_steps]
        else:
            # Decay: initial_lr * sqrt(warmup_steps) / sqrt(last_epoch).
            return [decay_factor * self.last_epoch ** -0.5 for decay_factor in self._decay_factors]
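
For context, this is roughly how I have been driving the scheduler (a minimal sketch; the single throwaway parameter and SGD optimizer are just stand-ins for my real setup):

import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=1e-4)
scheduler = InverseSquareRootLR(optimizer, warmup_steps=10_000)

for step in range(20_000):
    optimizer.step()
    scheduler.step()
    if step % 5_000 == 0:
        # get_last_lr() reports the lr the scheduler just set.
        print(step, scheduler.get_last_lr())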

Modified code (note what happens in get_lr when self.last_epoch >= self._warmup_steps):

import warnings

from torch.optim.lr_scheduler import _LRScheduler

class InverseSquareRootLR(_LRScheduler):
    def __init__(self, optimizer, warmup_steps, last_epoch=-1):
        if warmup_steps <= 0:
            raise ValueError('warmup_steps must be > 0')
        self._warmup_steps = warmup_steps
        self._lr_steps = [param_group['lr'] / warmup_steps for param_group in optimizer.param_groups]
        self._decay_factors = [
            param_group['lr'] * warmup_steps ** 0.5 for param_group in optimizer.param_groups
        ]

        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        if self.last_epoch < self._warmup_steps:
            # Warmup: unchanged, scale linearly from 0 up to the initial lr.
            return [self.last_epoch * lr_step for lr_step in self._lr_steps]
        else:
            decay_steps = [decay_factor * self.last_epoch ** -0.5 for decay_factor in self._decay_factors]
            # Divide the initial lrs (self.base_lrs, stored by _LRScheduler) by the
            # decay value, rather than the live param_group['lr'], which the
            # scheduler itself overwrites on every step.
            return [
                base_lr / decay_step for base_lr, decay_step in zip(self.base_lrs, decay_steps)
            ]

Warm regards :)
