- TL;DR
- Introduction: 4 mysteries of training NNs
- A simple Quadratic Stochastic Model
- The Hessian Spectrum Throughout Training
- Testing the Quadratic Approximation
- Is there Structure in the Eigenvector Components?
- Are the Eigenvectors Mostly Constant?
- A plausible hypothesis: Narrowing Valleys
- Testing the Narrowing Valley Hypothesis
- The multivariate narrowing valley model
- Conclusion and Further Questions
In this blog post I attempt to answer the following questions about the behavior of neural networks trained by supervised learning:
- why does lowering the learning rate during training result in a sharp drop in error?
- why is increasing the batch size equivalent to lowering the learning rate?
- why does the largest eigenvalue of the hessian increase as learning rate drops?
- why do negative eigenvalues persist in the hessian throughout training?
It turns out that thinking of the loss landscape of neural networks as a long valley with increasingly tightening walls (see picture below), stochastically translated (but not stretched) in random directions at every minibatch, can elegantly answer these 4 questions.
I replicate all 4 effects above in the context of a small resnet trained on cifar10 and provide evidence for the "stochastic narrowing valley" hypothesis by first solving the dynamics of SGD-with-momentum for a stochastic quadratic function, then by computing the eigenspectrum of the hessian of the resnet at many training checkpoints.
The most crucial sections for understanding the point of this blog post are sections 2, 7 and 8, which together describe the 2D toy landscape that answers the 4 questions above. If you don't have much time, read only those sections.
This is a blog post about investigating a bunch of weird behaviours that happen when training modern neural networks. Often, newcomers to deep learning look at the size of current models, observe that the optimization problem is non-convex in dimension 10^8, and throw up their hands in despair, abandoning any attempt at building intuitions about the loss landscape of realistic networks. As we'll see, that despair might be a bit premature, as we can build very simple 1D or 2D loss landscapes that capture a lot of the properties we observe in modern deep learning, while still being very easy to visualize.
That being said, here are 4 puzzling things that happen when training networks:
- Sharp Loss Decrease at Learning Rate Drops: we observe a sharp, cliff-like drop in loss every time we suddenly decrease the learning rate.
- Equivalence Between High Batch Size and Low Learning Rate: Increasing the batch size and decreasing learning rates have the same effect on the loss.
- Edge of Stability: The highest eigenvalue of the hessian rises precisely to match the maximum value allowed by the current learning rate.
- Persistent Negative Eigenvalues: The negative eigenvalues of the hessian persist throughout training; they don't get optimised away.
Somewhat surprisingly, it turns out that these 4 separate phenomena can be unified into a simple coherent picture of what's going on. But first, let's think a bit about why each of these mysteries is at odds with the simple Calculus 1 "gradient descent to a local minimum" picture of what's going on.
Mystery #1: In the Calculus 1 view, the learning rate is roughly "how big of a step" we're taking down the mountain. If we suddenly reduce our step size, it makes no sense for the loss to suddenly drop off a cliff; we'd expect to just make much slower progress down the mountain.
Mystery #2: Of the four effects above, this is the easiest for the Calculus 1 view to accommodate: increasing the batch size lowers the variance of your gradient estimate, and it makes sense to take careful, small steps down the mountain if you are uncertain about the current slope. Yet the exact quantitative equivalence is harder to explain: multiplying the batch size by 5 has almost exactly the same effect as dividing the learning rate by 5.
Mystery #3: A decrease in learning rate somehow makes the network go towards a region of the landscape that is as sharp as the current learning rate would allow before the optimization starts diverging. What's keeping the eigenvalue from ballooning up even higher? Is the loss landscape somehow fractal in nature? And every time we decrease the learning rate we drop down into a pit that we were previously skipping over?
Mystery #4: Ordinarily the negative eigenvalue directions should be the easiest to optimise, since they are the most unstable: each step down the slope increases the magnitude of the gradient pushing you further downhill. Yet, as we'll see, a sizeable fraction of the eigenvalue spectrum at every point in training consists of negative values. What's keeping them from being optimised away?
The simplest model that exhibits the strange cliff-drop-at-lr-decrease feature is optimising a quadratic function whose minimum is randomly translated at every step:

$$f_t(x) = \frac{\lambda}{2}(x - c_t)^2$$

where $c_t \sim \mathcal{N}(0, s^2)$ is redrawn independently at every iteration. This corresponds to the variance of the noise induced by minibatching: increasing the batch size is the analog of decreasing $s^2$.
At equilibrium, the expected loss settles to

$$\mathbb{E}[f]_{\text{eq}} = \frac{\eta\lambda^2 s^2}{2(2-\eta\lambda)} \approx \frac{\eta\lambda^2 s^2}{4}$$

This is telling us that SGD on such a stochastic function drops down to a level roughly proportional to both the learning rate $\eta$ and the noise variance $s^2$.
Hence this simple model seems to exhibit both Mystery #1 and Mystery #2. Learning rate drops are always accompanied by sharp drops in loss, as the optimisation settles to a new equilibrium. And there is an almost exact equivalence between dropping the learning rate and dropping the variance of the stochastic term (the analog to increasing the batch size).
If we directly plot the loss of such a model, periodically dropping the learning rate by a factor of 10, it looks something like this:
Notice the log scale on the y axis, and the fact that each new minimum level is exactly an order of magnitude below the previous one, exactly as predicted by the theory.
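This 1D model is easy to reproduce in a few lines of numpy. A minimal sketch (the constants $\lambda = s = 1$ and the schedule below are illustrative, not fit to any of the plots):

```python
import numpy as np

def run_sgd(lr_schedule, lam=1.0, s=1.0, seed=0):
    """SGD on f_t(x) = lam/2 * (x - c_t)^2, where the minimum c_t is
    redrawn from N(0, s^2) at every step; returns the loss lam/2 * x^2
    of the noise-free average function."""
    rng = np.random.default_rng(seed)
    x, losses = 5.0, []
    for lr in lr_schedule:
        c = rng.normal(0.0, s)           # random translation of the minimum
        x -= lr * lam * (x - c)          # SGD step on the sampled quadratic
        losses.append(0.5 * lam * x**2)
    return np.array(losses)

# Drop the learning rate by 10x twice; the equilibrium loss level should
# drop by an order of magnitude each time (~ lr * lam^2 * s^2 / 4).
losses = run_sgd(np.concatenate([np.full(30000, lr) for lr in (1e-1, 1e-2, 1e-3)]))
for i, lr in enumerate((1e-1, 1e-2, 1e-3)):
    eq = losses[i * 30000 + 15000 : (i + 1) * 30000].mean()
    print(f"lr={lr:.0e}: measured equilibrium {eq:.2e}, predicted {lr / 4:.2e}")
```

Averaging over the second half of each phase (after the transient) recovers the predicted one-order-of-magnitude drop per lr decrease.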
As an aside, we can derive (with just slightly more effort, see the appendix) an analogous equation for the case of SGD-with-momentum, where $\beta$ is the momentum coefficient:

$$\mathbb{E}[f]_{\text{eq}} \approx \frac{\eta\lambda^2 s^2}{4(1-\beta)}$$

Which of course reduces to the non-momentum case when $\beta = 0$.
We can test the formula above empirically on the test quadratic surface, and see that it accurately predicts the loss level:
The two curves have the same learning rate; they differ only in the momentum term. Notice the faster initial slope of sgd-with-momentum, at the cost of a higher equilibrium loss level.
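Extending the numpy sketch with heavy-ball momentum, we can check numerically that the equilibrium loss is amplified by roughly $1/(1-\beta)$ (a leading-order estimate, under the same assumptions as above; constants again illustrative), consistent with the higher floor of the momentum curve:

```python
import numpy as np

def run_momentum_sgd(steps, lr, beta, lam=1.0, s=1.0, seed=0):
    """Heavy-ball SGD on f_t(x) = lam/2 * (x - c_t)^2, c_t ~ N(0, s^2)."""
    rng = np.random.default_rng(seed)
    x, v, losses = 5.0, 0.0, []
    for _ in range(steps):
        g = lam * (x - rng.normal(0.0, s))   # stochastic gradient
        v = beta * v + g
        x -= lr * v
        losses.append(0.5 * lam * x**2)
    return np.array(losses)

for beta in (0.0, 0.9):
    eq = run_momentum_sgd(200_000, lr=1e-2, beta=beta)[100_000:].mean()
    pred = 1e-2 / (4 * (1 - beta))   # eta*lam^2*s^2/(4*(1-beta)) with lam=s=1
    print(f"beta={beta}: measured {eq:.2e}, leading-order prediction {pred:.2e}")
```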
Extending this effect to $n$ dimensions is straightforward: a generic positive definite quadratic function can be written as a sum over the eigendirections of its Hessian,

$$f(\theta) = \sum_i \frac{\lambda_i}{2}\,\langle \theta - c,\, v_i\rangle^2$$

so under the stochastic-translation model each eigendirection behaves as an independent copy of the 1D problem, with its own eigenvalue $\lambda_i$ and noise scale $s_i$.
Having found the behavior of SGD on our simple stochastic quadratic loss in the limit of equilibrium, we now ask what happens out of equilibrium. In this toy model, we will assume that noise is essentially irrelevant until the loss approaches the equilibrium level, so that above that level the descent is effectively deterministic: the loss decays as $f_t \approx (1-\eta\lambda)^{2t} f_0$, hitting the noise floor after roughly

$$t_{\text{eq}} \approx \frac{\ln\!\left(f_0 / \mathbb{E}[f]_{\text{eq}}\right)}{2\eta\lambda}$$

iterations. We therefore have an expression for the expected number of iterations it takes for SGD to descend down to a level where its dynamics are dominated entirely by the noise.
On quadratic functions for which
Hence the real impact of momentum is to allow the very small eigenvalues to have effective learning rates of $\eta/(1-\beta)$, substantially shortening the time they take to reach equilibrium.
If we imagine sampling the parameters from their posterior distribution, we would expect the variance in each eigendirection to be
This suggests a mechanism through which different optimisation algorithms and noise injection schemes might be helping generalisation: they're changing
The fundamental lesson of this simple model is likely that the noise in our optimisation function is a crucial factor to keep in mind when thinking about loss landscapes. Mysteries #1 and #2 above are fundamentally noisy phenomena. This toy model is also evidence against the landscape being somehow fractal in nature given that we don't need to invoke such a complicated structure to explain the sharp loss decreases.
Now that we've derived a plausible model for what's happening in the loss landscape, let's investigate the landscape of a real neural network by explicitly computing the full Hessian at multiple points in training. Here's the setup for the experiment:
- CIFAR10 dataset without data augmentation
- Very Tiny Resnet model with GELU activations (for twice differentiability), only 26000 parameters in total
- SGD with momentum. lr=1e-1, momentum=0.97, weight-decay=1e-3
- 500 epoch total training
- lr decreases by 10 at iterations = 10000, 20000, 30000, 40000
- batch size 512
- final accuracy of 80% on validation set
Computing the full dataset Hessian is only feasible for very small models, which is why we choose such a small resnet.
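For concreteness, here's a minimal sketch of how a full Hessian spectrum can be computed (the actual experiment presumably uses an autodiff framework on the 26000-parameter resnet; this toy version uses central finite differences on a made-up 5-parameter regression model):

```python
import numpy as np

def full_hessian(loss, theta, eps=1e-4):
    """Dense Hessian of a scalar loss at theta via central differences."""
    n = len(theta)
    H = np.zeros((n, n))
    for i in range(n):
        for j in range(n):
            def shifted(di, dj):
                t = theta.copy()
                t[i] += di
                t[j] += dj
                return loss(t)
            H[i, j] = (shifted(eps, eps) - shifted(eps, -eps)
                       - shifted(-eps, eps) + shifted(-eps, -eps)) / (4 * eps**2)
    return 0.5 * (H + H.T)  # symmetrise away numerical asymmetry

# Made-up 5-parameter "network": one tanh unit fit to random data.
rng = np.random.default_rng(0)
X, y = rng.normal(size=(32, 3)), rng.normal(size=32)

def mse_loss(theta):
    w, b, a = theta[:3], theta[3], theta[4]
    return np.mean((a * np.tanh(X @ w + b) - y) ** 2)

theta = rng.normal(size=5)
print("spectrum:", np.linalg.eigvalsh(full_hessian(mse_loss, theta)))
```

The $O(n^2)$ loss evaluations make clear why this kind of full-spectrum analysis is restricted to very small models.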
Here's what the minibatch loss looks like over time:
Notice the sharp drops at iterations 10000 and 20000, corresponding to dividing the learning rate by 10. Now let's take every checkpointed network and compute its total loss on the training set, as well as the largest eigenvalue of the hessian of the full training set at that point:
The loss cliff drops become much cleaner, and we can see an extra drop appear at iteration 30000 and a very small one at 40000. There's also clear evidence of the "edge of stability" phenomenon: the top eigenvalue keeps increasing throughout training, and it shoots up quickly after each learning rate decrease.
Now let's look at the full spectrum of the Hessian, and how it evolves through training. In the figure below we're plotting a log-log graph of the sorted eigenvalues at various points in training, denoted by the iteration number.
A few observations:
- The positive eigenvalues have a roughly power-law shape to them. The 100-th biggest value is roughly 100 times smaller than the top value. There are many small eigenvalues and few top eigenvalues, but the shape of the spectrum is such that the total power at every scale is roughly the same.
- The positive spectrum mostly keeps the same overall shape through training, except for an overall translation upwards, which corresponds to the Edge-of-Stability effect.
- We still have negative eigenvalues even at the very end of training, though the shape of the negative spectrum does seem to change, and the total number of non-negligible negative values steadily drops (the dropoff of value with rank shifts leftward as training progresses)
- The pure quadratic assumption is already violated by the changing spectrum (and by the existence of negative eigenvalues).
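The "same total power at every scale" observation is exactly what a $1/\mathrm{rank}$ spectrum would give. A quick check with an idealised $\lambda_k = 1/k$ spectrum (a stand-in, not the measured eigenvalues):

```python
import numpy as np

# Idealised power-law spectrum: the k-th largest eigenvalue is 1/k.
lam = 1.0 / np.arange(1, 100_001)

# Summing the eigenvalues inside each decade of rank gives a roughly
# constant total (~ln(10) ≈ 2.3): equal power at every scale.
for lo in (1, 10, 100, 1000, 10_000):
    print(f"ranks {lo:>6}..{10 * lo - 1:>6}: total power = "
          f"{lam[lo - 1 : 10 * lo - 1].sum():.3f}")
```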
In the derivation of the stochastic quadratic model, we made a really big assumption about the form of the stochasticity: namely, that the function's shape stays the same, and only its minimum is shifted randomly from sample to sample. One could imagine other forms of stochasticity: for instance, the sharpness of the minimum might change in addition to its location, or the minimum location might vary in a non-Gaussian way. Here we test this assumption for our network at iteration = 10000, i.e. right before the first decrease in lr (though these results replicate at every other point in training).
To test this assumption, for each eigenvector $v_\lambda$ we perform line searches of the minibatch losses along that direction, plotting $f_i(\theta + t\,v_\lambda)$ as a function of $t$, with a separate curve for each minibatch $i$:
A few observations:
- The sharpness of the function doesn't change from minibatch to minibatch; the curves are all basically the same shape, up to an irrelevant vertical translation
- The minimum locations do seem to be distributed normally, no weird surprises or outliers here.
- Overall the toy stochastic quadratic model seems to perfectly describe what's going on here.
The positive eigenvalue directions behave as the toy model expected, but what about the negative directions? Doing the same procedure as above, this is what we get, again plotting line searches along a particular eigenvector, with different curves representing different minibatches. Blue points are the global minima of the functions.
- Again the function shape remains fairly consistent batch-to-batch: almost every function has two local minima, and they all cluster around the same two points on the x-axis.
- one of the two local minima is clearly lower than the other, but not all minibatches agree on which of the two is the correct one.
- The most surprising fact here is that these directions have not yet been optimised away. In these plots the middle point represents the unchanged parameters of the network, i.e. $f_i(\theta + 0 \cdot v_\lambda)$, and we see that this point lies at a local maximum of the function. These negative eigenvalues are also quite large in magnitude; it's not the case that this direction is just too flat for SGD to make progress. Some unknown mechanism is keeping the network at a local maximum in this direction.
The toy quadratic model has a free parameter per direction: the standard deviation $s_\lambda$ of the minimum's location along each eigenvector. Having measured the minibatch minima above, we can now estimate it directly as a function of the eigenvalue:
So we see a decrease in standard deviation for larger eigenvalues, but notice the log scale on the x-axis: an order of magnitude increase in eigenvalue gives us a measly
Now for each eigenvalue we plug the relevant factors into the equation for the equilibrium variance of the oscillations in that direction:
Perhaps surprisingly, higher eigenvalues tend to oscillate slightly more at equilibrium (apart from a few outliers at the very high end). But again the log scale on the x-axis implies that there's remarkably little change over the wide range of values that the eigenvalues pass through.
Given the variance
We now plot this
Notice the log scales on both the y and x axis this time. The very highest eigenvalues take almost no time at all to be optimised, whereas the lowest ones take as much as
This graph gives us a plausible hypothesis for why we need to train at high learning rates: the low eigenvalues take a long time to reach equilibrium. We can therefore see the tradeoff between low lr and high lr as follows:
- We need high lr in order to reduce the time it takes to reach equilibrium in the low eigenvalue directions.
- But we need low lr in order to decrease the variance of the oscillations at equilibrium, and reach a lower loss level.
If we lower the learning rate too quickly, we get a sudden drop in loss as all the eigenvalues that were already at equilibrium drop to an even lower level. But the price we pay is that the directions that weren't yet at equilibrium will now take much longer to get there, because the learning rate is much smaller; we've crippled our long-term potential.
Once all directions have reached equilibrium, no more progress is possible at this learning rate, and we need to lower the learning rate or increase the batch size to make any further progress. However, as we'll see in section 7, this effect is not enough to completely explain why small learning rates are important; it explains some of the effect, but only a fraction of it.
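The inverse scaling of equilibration time with $\eta\lambda$ behind this tradeoff can be checked with a small sketch (illustrative constants; the noise floor is taken to be the $\eta\lambda^2 s^2/4$ equilibrium level of the 1D model):

```python
def time_to_noise_floor(lr, lam, s=1.0, x0=100.0):
    """Iterations until noise-free SGD on f(x) = lam/2 * x^2, starting at
    x0, first dips below the stochastic equilibrium level lr*lam^2*s^2/4."""
    floor = lr * lam**2 * s**2 / 4
    x, t = x0, 0
    while 0.5 * lam * x**2 > floor:
        x *= 1 - lr * lam   # deterministic decay of the mean
        t += 1
    return t

# Halving the eigenvalue roughly doubles the time to reach the floor:
for lam in (1.0, 0.5, 0.25):
    print(f"lam={lam}: reaches the noise floor after "
          f"{time_to_noise_floor(lr=1e-2, lam=lam)} iterations")
```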
Now that we have the variance
As we can see, something goes terribly wrong: the toy model is predicting a total loss decrease an order of magnitude higher than the total loss of the network. Cross-Entropy is bounded below by 0, hence it's impossible to get a loss decrease larger than the red line. I don't know what's going on here, perhaps the addition of gradient clipping and weight decay is messing up the simple sgd-with-momentum math, or the noise is non-gaussian in a way that makes our assumptions break down. (The noise merely being correlated between dimensions wouldn't be enough to explain this.)
Switching gears a bit for the moment, let's look at the eigenvectors corresponding to each of our eigenvalues and try to figure out if there's any internal structure to them that we can find. An eigenvector
To do this, for each vector $v_\lambda$ we sort its squared components in decreasing order and plot their cumulative sum: the faster the curve saturates, the more the vector's norm is concentrated in a small number of components.
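As a concrete sketch of this measurement (numpy, with synthetic stand-ins for the real eigenvectors): a generic dense direction spreads its squared norm over many components, while a direction dominated by ~100 components saturates its cumulative curve almost immediately:

```python
import numpy as np

def concentration_curve(v):
    """Fraction of the vector's squared norm captured by its k largest
    magnitude components, for every k (components sorted descending)."""
    p = np.sort(np.asarray(v) ** 2)[::-1]
    return np.cumsum(p) / p.sum()

rng = np.random.default_rng(0)
n = 26_000  # same dimensionality as the tiny resnet's parameter count

dense = rng.normal(size=n)                    # generic direction
mask = rng.random(n) < 0.004                  # ~100 dominant components
sparse = np.where(mask, rng.normal(size=n), 0.01 * rng.normal(size=n))

print("dense : fraction of norm in top 100 =", round(concentration_curve(dense)[99], 3))
print("sparse: fraction of norm in top 100 =", round(concentration_curve(sparse)[99], 3))
```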
Let's do the same for the negative eigenvalues; red curves correspond to eigenvalues of high absolute value.
A few points:
- There's definitely a clear pattern where the eigenvectors corresponding to large eigenvalues are much more sharply distributed than those from smaller eigenvalues.
- A few outlier directions have as much as 80% of their squared sum being concentrated in merely 100 components of the network
- The same basic pattern happens with both negative and positive eigenvalues.
- A plausible hypothesis for what's happening here is that total loss is very sensitive to a minority of the parameters, probably corresponding to the biases of the network and the learned batchnorm variances, and this is showing up in the eigenvectors.
Now we turn to the question of figuring out how the eigenvectors of the Hessian change over the course of optimization. Visualizing changes in a 26000 by 26000 matrix is non-trivial, so we need a bit of inventiveness to extract some interesting results here.
We will consider the eigenvectors of the spectrum at 4 different points in training: iterations 2000, 10000, 20000, and 30000. Given a particular eigenvector at one of these points, we are interested in asking how close it is to the eigenvectors at the previous point in training, and in particular whether it stays closest to the old eigenvectors whose eigenvalues roughly match its own.
Meaning, is the new eigenvector just changing in a random direction in parameter space, or is its change biased towards directions that had roughly the same eigenvalues?
To be specific, consider the eigenvector
In the images below, the x axis corresponds to
To take into account the fact that there are many more small eigenvalues than large ones, we plot the power
comments:
- The eigenvectors are in fact changing significantly. No change at all would correspond to each curve being a single infinitely thin spike, since each eigenvector would project entirely onto itself and have zero projection onto the rest of the (orthogonal) old basis.
- Vectors with a new value $\lambda$ are most similar to the old vector with value $\lambda$, i.e. each curve in the graphs has a maximum that corresponds to its own $\lambda$ value.
- There's a smooth and predictable bias towards the change happening in vectors with neighboring eigenvalues. A vector with a high eigenvalue won't suddenly change in directions with low eigenvalue; it'll only rotate into directions that are in a neighborhood of itself in terms of eigenvalue. This is surprising, as we might've imagined a "spike-and-slab" model, where an eigenvector stays most similar to itself but otherwise rotates in a random direction in parameter space. This is not what seems to happen here.
- The vectors seem to change the most in the early iterations: in the graph from iteration 20000 to 30000, many more of the eigenvectors have sharp power curves, corresponding to the fact that they're mostly staying similar to their old eigenvectors.
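The qualitative pattern in these projection plots is easy to reproduce in a toy setting: take a symmetric matrix with a spread-out spectrum, perturb it slightly, and look at the squared overlaps between old and new eigenvectors (purely illustrative; these are not the network's Hessians):

```python
import numpy as np

rng = np.random.default_rng(0)
n = 50
# "Old" Hessian: well-separated spectrum 1..50 in a random orthogonal basis.
Q, _ = np.linalg.qr(rng.normal(size=(n, n)))
H_old = Q @ np.diag(np.arange(1.0, n + 1.0)) @ Q.T
# "New" Hessian: the old one plus a small symmetric perturbation.
E = rng.normal(size=(n, n))
H_new = H_old + 0.02 * (E + E.T)

w_old, V_old = np.linalg.eigh(H_old)
w_new, V_new = np.linalg.eigh(H_new)

# P[i, j] = squared overlap between old eigenvector i and new eigenvector j.
P = (V_old.T @ V_new) ** 2
j = n - 1   # top new eigenvector
print("old index with largest overlap:", P[:, j].argmax())
print("overlap mass within +/-3 neighbours:", P[max(0, j - 3):j + 4, j].sum())
```

First-order perturbation theory mixes eigenvector $j$ with eigenvector $i$ in proportion to $1/|\lambda_i - \lambda_j|$, which is one simple mechanism that would produce the "rotate only into spectral neighbours" bias seen above.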
While the stochastic quadratic approximation predicts that loss should drop when we drop the learning rate, it does not predict that the maximum eigenvalue should rise, nor does it predict that negative eigenvalues remain until the end of training. Can we build an example of a non-quadratic 2D function which exhibits these two properties, or do we need high dimensions to explain these phenomena? The answer turns out to be that yes, we can construct such a function. Consider
This function has a global minimum at
Something interesting is happening with the x component of the gradient, when
So the
What about the negative eigenvalues? Does this simplified model predict that we'll observe negative values? Let's plot the regions of
Again we see that only a thin band around
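The post's exact 2D function isn't reproduced here, but any narrowing valley shows the same signature. As a hypothetical stand-in, take $f(x,y) = \tfrac{1}{2}e^{2y}x^2 - \epsilon y$ (an illustrative choice, not the function above): its Hessian has a negative eigenvalue everywhere except on the valley floor $x = 0$:

```python
import numpy as np

EPS = 0.01  # strength of the tilt pulling us down the valley

def hessian_eigs(x, y):
    """Hessian eigenvalues of the hypothetical narrowing valley
    f(x, y) = 0.5 * exp(2y) * x^2 - EPS * y  (the linear tilt term
    contributes nothing to the Hessian)."""
    k = np.exp(2 * y)
    H = np.array([[k,         2 * x * k],
                  [2 * x * k, 2 * x**2 * k]])
    return np.linalg.eigvalsh(H)

print("on the valley floor  (x=0, y=0):", hessian_eigs(0.0, 0.0))
print("off the valley floor (x=0.3)   :", hessian_eigs(0.3, 0.0))
print("floor, further down  (x=0, y=1):", hessian_eigs(0.0, 1.0))
```

Here the determinant is $-2x^2 e^{4y} \le 0$, so one eigenvalue is negative whenever $x \neq 0$, while the wall sharpness on the floor grows as $e^{2y}$ as we move down the valley.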
This simplified model also provides a possible explanation for the "Edge Of Stability" effect (eigenvalues increase as we decrease learning rate): as we decrease the learning rate, we decrease our variability in
- We begin optimisation at some random point in the landscape; gradient descent quickly descends down the high eigenvalue directions until it gets into equilibrium with the noise in those directions, and then we oscillate in the high-$\lambda$ directions with some variance $s^2$. However, the low-but-positive-$\lambda$ directions take longer to get optimised, and they benefit from keeping the learning rate higher for longer.
- The oscillation in the high eigenvalue direction is preventing optimisation from occurring in the narrowing directions, because a significant fraction of the oscillations bring the network into regions of the landscape where the gradient is pushing it away from the minimum.
- While we oscillate in the high-$\lambda$ directions, the negative eigenvalues don't go away, because we keep jumping over the narrow region where the negative values would disappear.
- When we finally decrease the learning rate (or increase the batch size) by some fixed amount, two things happen: first, the high-$\lambda$ directions drop to a lower equilibrium level, which quickly drops the loss. Then, because we're now oscillating at a lower level, we can correctly "see" the gradient in the narrowing valley directions; this lets us optimise those directions, leading to a further drop in loss.
- We settle into a new equilibrium at some point down the narrow valley, the largest eigenvalue of the Hessian increases to reflect the narrowing walls of the new equilibrium point (the Edge of Stability effect), and the whole process repeats at the next learning rate drop.
This story seems to concisely explain the "4 mysteries" of training neural networks that we considered at the beginning of this post. Sharp loss decreases, high-batch-low-lr equivalence, the edge of stability, and persistent negative eigenvalues are all effects that naturally fall out of a stochastic landscape where some directions have narrowing valleys.
The narrowing valley hypothesis makes one unambiguous prediction that we should be able to test in realistic networks: if we optimise the loss function in the subspace defined by the highest eigenvalues of the hessian, the number of negative eigenvalues of the hessian at that local minimum should decrease. In the context of the toy function from the previous section, that corresponds to finding the thin blue strip in the previous figure: optimising the y-dimension should bring us within a region of parameter space where there are no more negative eigenvalues.
Note also that this is a non-trivial prediction. If the loss landscape could be well approximated merely by a quadratic function where some of the eigenvalues were negative (i.e. the typical saddle shape), then we would not expect that minimizing the positive eigendirections would have any effect at all on the negative spectrum. Nor would we expect the positive directions to influence the negative ones if the loss factorised into a sum of independent terms, one per eigendirection.
To test this prediction in our small but non-trivial model, we pick again the iteration=10000 checkpointed network and perform full-batch SGD for 1000 iterations with lr=1e-3 and momentum=0.9 in the subspace defined by the top n eigenvectors, where n varies on a logarithmic scale from 6 to 2000. After having found the local minimum within that subspace, we compute the bottom 2000 negative eigenvalues of the network at that point (computing the whole Hessian is too expensive here), and plot the negative spectrum for multiple values of n, the number of top eigenvectors we optimise:
And we see a very robust decrease in the magnitude of the negative eigenvalues as we optimise more and more of the high eigenvalue directions.
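The same effect can be sketched on a hypothetical narrowing-valley function $f(x,y) = \tfrac{1}{2}e^{2y}x^2 - \epsilon y$ (an illustrative stand-in, not the network's landscape): off the valley floor the Hessian has a negative eigenvalue, and optimising only along the high-curvature direction removes it:

```python
import numpy as np

EPS = 0.01

def grad(p):
    """Gradient of f(x, y) = 0.5 * exp(2y) * x^2 - EPS * y."""
    x, y = p
    k = np.exp(2 * y)
    return np.array([k * x, k * x**2 - EPS])

def hessian(p):
    x, y = p
    k = np.exp(2 * y)
    return np.array([[k, 2 * x * k], [2 * x * k, 2 * x**2 * k]])

p = np.array([0.4, 0.0])
print("min Hessian eigenvalue before:", np.linalg.eigvalsh(hessian(p)).min())

# Gradient descent restricted to the top-curvature subspace (here just the
# x direction, which approximates the top eigenvector near the floor):
top = np.array([1.0, 0.0])
for _ in range(500):
    p -= 0.1 * (grad(p) @ top) * top

print("min Hessian eigenvalue after: ", np.linalg.eigvalsh(hessian(p)).min())
```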
Suppose that we expand the model into
This model has the same minimum as the usual quadratic approximation, and has the same hessian at the minimum itself, but the hessian changes for small displacements from the minimum by an amount given by:
These changes are to first order in the displacement from the minimum.
From the perturbation theory of linear operators, the first order change in an eigenvalue $\lambda_i$ under a perturbation $\delta H$ of the Hessian is

$$\delta\lambda_i = v_i^\top\, \delta H\, v_i$$

where $v_i$ is the eigenvector corresponding to $\lambda_i$.
Each eigenvalue
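This first-order relation, $\delta\lambda_i = v_i^\top \delta H\, v_i$, is easy to verify numerically on a generic symmetric matrix (nothing network-specific about this check):

```python
import numpy as np

rng = np.random.default_rng(0)
n = 8
G = rng.normal(size=(n, n))
H = G + G.T                      # base symmetric "Hessian"
D = rng.normal(size=(n, n))
dH = 1e-5 * (D + D.T)            # small symmetric perturbation

w, V = np.linalg.eigh(H)
w_pert = np.linalg.eigh(H + dH)[0]

# First-order perturbation theory: d(lambda_i) = v_i^T dH v_i
first_order = np.array([V[:, i] @ dH @ V[:, i] for i in range(n)])
print("max deviation from the exact eigenvalue shifts:",
      np.abs((w_pert - w) - first_order).max())
```

The residual is second order in the perturbation, so shrinking `dH` by 10x should shrink the deviation by roughly 100x.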
Future directions:
- Why does it take so long to optimise neural networks? Can we use the narrowing valley hypothesis to build a model which lets us accurately predict the loss given a learning rate schedule and optimisation algorithm?
- Are there many different narrowing valleys we could fall down into? i.e. if we're oscillating at equilibrium at some learning rate, does the particular point at which we decide to drop the learning rate send us down different narrowing valleys? Or does it not matter?
- How much support do the high eigenvectors have over the data? i.e. are the high eigenvalues due to all datapoints having large dependence on those directions, or do a small number of points have an outsized impact on them?
- Do these results generalise to larger resnets? What about transformer architectures?
- Are the very small (in absolute terms) eigenvalues important to minimize? Could we restrict the optimisation to a few of the highest eigenvectors as well as the negative vectors, and not lose meaningful performance?
- What, if any, is the connection with the Lottery Ticket Hypothesis?
- Which features of the loss landscape are responsible for overfitting? i.e. should we want to go down the narrowing valley, or is going down the valley the price we pay for needing to find the minimum of the high eigenvalues?
- What features of the data and/or the learned representations are responsible for the high/low eigenvalues, and the narrowing effect of the landscape?
- Can we determine an optimal learning rate schedule from our knowledge of the eigenvalue spectrum and the noise level in each dimension?
- Can we design architectures whose loss functions exhibit less of a narrowing effect, thereby being trainable with higher learning rates?
- Can we design architectures where most of the high-$\lambda$ vectors are sequestered to a small number of network components? This would allow us to use much higher learning rates in the rest of the network without risking divergence.
- How do we efficiently minimise noise in the larger eigenvalues while still making large steps in the low-eigenvalue directions?
- Do the eigenvectors of the network change when oscillating at equilibrium? What causes their change?
- Can we predict oscillations in and out of the valley due to SGD noise?
- Does weight decay and/or grad clipping meaningfully affect these results?
When adding momentum, the equations become:
If we assume that subsequent