jdepoix / test-completion-transformer

A syntax-aware Transformer Model for Test Completion
@jdepoix, I have been looking around at learning rate schedulers, and I saw your very nice implementation here. I believe I'm either misunderstanding something or there might be a small bug in your scheduler. Thanks for any thought you put toward this!
I have been using your scheduler code with 10,000 `warm_up_steps` and a starting `lr` of 0.0001. During the warmup phase:

- `lr` will start at 0.0001 / 10,000 = 1.0e-8, the per-group step size stored in `self._lr_steps`.
- In `.get_lr`, we multiply `self.last_epoch` by that step size, so when `self.last_epoch` is 0, we get 0 for the model `lr`.
- When `self.last_epoch` is 10, the model `lr` is 1.0e-7.
- When `self.last_epoch` is 100, the `lr` will be 1.0e-6.
- At `self.last_epoch` = 10,000, we reach the full `lr` of 1.0e-4.

After hitting the number of `warm_up_steps` (in the example, 10,000 `warm_up_steps` and a starting `lr` of 0.0001), the `lr` will start to decay, with `self._decay_factors` set to 0.0001 * sqrt(10,000) = 0.01. This means that:

- at step 11,000, sqrt(11,000 / 10,000) = 1.04880884817
- at step 20,000, sqrt(20,000 / 10,000) = 1.41421356237
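To make the arithmetic above easy to replay, here is a minimal standalone sketch (plain Python, no PyTorch required; `initial_lr` and `warmup_steps` are just the example values from this issue):

```python
# Replays the warmup/decay arithmetic described above.
initial_lr = 0.0001
warmup_steps = 10_000

lr_step = initial_lr / warmup_steps               # 1.0e-8, as in self._lr_steps
decay_factor = initial_lr * warmup_steps ** 0.5   # 0.01, as in self._decay_factors

for step in (0, 10, 100, 10_000, 11_000, 20_000):
    if step < warmup_steps:
        lr = step * lr_step                       # linear warmup
    else:
        lr = decay_factor * step ** -0.5          # inverse square root decay
    print(f"step {step:>6}: lr = {lr:.10f}")
```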
The warmup phase makes perfect sense to me. However, in the decay phase, should we be returning the initial learning rate scaled down by the values reported above? For example, at step 11,000, I would think we would want to return the initial `lr` of 0.0001 / 1.04880884817 = 0.00009534625. At step 20,000 we would then return 0.0001 / 1.41421356237 = 0.00007071067. This way, the learning rate decays according to the inverse square root function.
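As a quick numeric check of those two values (nothing scheduler-specific here, just the division spelled out):

```python
import math

initial_lr = 0.0001
warmup_steps = 10_000

# Proposed decay: divide the initial lr by sqrt(step / warmup_steps).
for step in (11_000, 20_000):
    lr = initial_lr / math.sqrt(step / warmup_steps)
    print(f"step {step}: lr = {lr:.11f}")
# step 11000: 0.0001 / 1.04880884817 = 0.00009534625...
# step 20000: 0.0001 / 1.41421356237 = 0.00007071067...
```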
Here's your current code. Not a lot changes with what I'm proposing.
```python
import warnings

from torch.optim.lr_scheduler import _LRScheduler


class InverseSquareRootLR(_LRScheduler):
    def __init__(self, optimizer, warmup_steps, last_epoch=-1):
        if warmup_steps <= 0:
            raise ValueError('warmup_steps must be > 0')
        self._warmup_steps = warmup_steps
        # per-group lr increment applied on every warmup step
        self._lr_steps = [param_group['lr'] / warmup_steps for param_group in optimizer.param_groups]
        # per-group constant chosen so decay starts exactly at the initial lr
        self._decay_factors = [
            param_group['lr'] * warmup_steps ** 0.5 for param_group in optimizer.param_groups
        ]
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)
        if self.last_epoch < self._warmup_steps:
            # linear warmup
            return [self.last_epoch * lr_step for lr_step in self._lr_steps]
        else:
            # inverse square root decay
            return [decay_factor * self.last_epoch ** -0.5 for decay_factor in self._decay_factors]
```
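For context, this is roughly how I am driving the scheduler: one scheduler step after every optimizer step. The tiny `Linear` model and `Adam` setup below are placeholders of my own, not anything from your repo:

```python
import torch

model = torch.nn.Linear(10, 1)                     # placeholder model
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
scheduler = InverseSquareRootLR(optimizer, warmup_steps=10_000)

for step in range(20_000):
    # forward/backward on a batch would normally happen here
    optimizer.step()
    scheduler.step()                               # advances last_epoch, updates each group's lr
    if step % 5_000 == 0:
        print(step, scheduler.get_last_lr())
```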
Modified code (note what happens in `get_lr` when `self.last_epoch` >= `self._warmup_steps`):
```python
import warnings

from torch.optim.lr_scheduler import _LRScheduler


class InverseSquareRootLR(_LRScheduler):
    def __init__(self, optimizer, warmup_steps, last_epoch=-1):
        if warmup_steps <= 0:
            raise ValueError('warmup_steps must be > 0')
        self._warmup_steps = warmup_steps
        self._lr_steps = [param_group['lr'] / warmup_steps for param_group in optimizer.param_groups]
        self._decay_factors = [
            param_group['lr'] * warmup_steps ** 0.5 for param_group in optimizer.param_groups
        ]
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)
        if self.last_epoch < self._warmup_steps:
            return [self.last_epoch * lr_step for lr_step in self._lr_steps]
        else:
            # decay: divide each group's initial lr (self.base_lrs, populated by
            # _LRScheduler) by sqrt(step / warmup_steps), so that at step 11,000
            # we return 0.0001 / 1.04880884817, at step 20,000
            # 0.0001 / 1.41421356237, and so on
            decay_step = (self.last_epoch / self._warmup_steps) ** 0.5
            return [base_lr / decay_step for base_lr in self.base_lrs]
```
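And a quick way to confirm the resulting schedule against the numbers above: drive the scheduler for 20,000 steps and inspect the lr at steps 11,000 and 20,000. This is just a throwaway harness; the single-parameter `SGD` setup is arbitrary:

```python
import torch

def schedule_values(scheduler_cls, total_steps, lr=0.0001, warmup_steps=10_000):
    # Record the lr produced after each scheduler step.
    params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = torch.optim.SGD(params, lr=lr)
    scheduler = scheduler_cls(optimizer, warmup_steps=warmup_steps)
    history = []
    for _ in range(total_steps):
        optimizer.step()
        scheduler.step()
        history.append(scheduler.get_last_lr()[0])
    return history

lrs = schedule_values(InverseSquareRootLR, 20_000)
print(lrs[10_999], lrs[19_999])  # expect ~0.00009534625 and ~0.00007071067
```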
Warm regards :)