- Luke Wood
The Perfect Solution to Loss Reduction in Deep Learning!
In deep learning, the most common optimization algorithm used in training is mini-batch stochastic gradient descent. In this algorithm, a random batch of data points is sampled from your training set, and weight gradients are computed with respect to the entire batch.
Computing the loss for a batch of samples yields one loss value per sample. Consider a classification example where your mini-batch consists of 64 images and their 64 corresponding labels. When computing the loss, you initially get a batch of losses: a vector of dimension 64 in this example.
When computing the gradient of the loss with respect to each weight in your network, you consider a scalar loss, not the 64-dimensional loss vector described above. But what strategy do you employ to convert the 64-dimensional vector to a scalar? There are two obvious candidate choices: sum() and mean(). These approaches are both capable of yielding identical numerical results, since they differ only by a constant factor of the batch size, and they may seem interchangeable on the surface. There is, however, an important tradeoff to consider when picking your loss reduction strategy.
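As a minimal sketch of the two reductions in plain Python, assuming a made-up batch in which 8 of the 64 samples have a loss of 1.0 and the rest have a loss of 0.0 (these values are illustrative, not taken from a real model):

```python
# Hypothetical per-sample losses for a batch of 64:
# 8 samples with loss 1.0 and 56 samples with loss 0.0.
per_sample_losses = [1.0] * 8 + [0.0] * 56

# sum() reduction: add up the per-sample losses.
sum_loss = sum(per_sample_losses)

# mean() reduction: the same sum divided by the batch size.
mean_loss = sum(per_sample_losses) / len(per_sample_losses)

print(sum_loss)   # 8.0
print(mean_loss)  # 0.125
```

The two scalars differ only by a factor of the batch size, which is why the choice looks harmless at first.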
Batch-level loss reporting
Let's start by discussing loss reporting.
The main benefit of using the mean() reduction strategy is that your final reported losses for each batch are scaled to resemble a sample-wise loss. This makes them extremely easy to reason about.
Let's walk through an example using the 0-1 loss used in traditional machine learning. Imagine you see a final loss of 0.125. With this reported number it is easy to figure out that you have misclassified 12.5% of the examples in your batch. Great!
Compare this to the output of the sum() reduction strategy: 8. What is a loss of 8? Well, it means there are 8 misclassified samples in the batch. Is that good? Is that bad? Well, we have 64 samples in the batch, so 8/64 samples were misclassified. You can reach the same result, but imagine you are trying to tune your BATCH_SIZE and reason about the results for each value. This quickly becomes extremely annoying.
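To make the comparison concrete, here is a small sketch in plain Python. The helper name `reported_losses` and the fixed 12.5% error rate are assumptions made for illustration; the point is only to show how each reduction reports the same error rate at different batch sizes.

```python
def reported_losses(batch_size, error_rate=0.125):
    """Return (sum_loss, mean_loss) of 0-1 losses for one batch.

    Assumes exactly `error_rate` of the batch is misclassified.
    """
    num_wrong = round(batch_size * error_rate)
    losses = [1.0] * num_wrong + [0.0] * (batch_size - num_wrong)
    return sum(losses), sum(losses) / batch_size


for batch_size in (16, 64, 128):
    sum_loss, mean_loss = reported_losses(batch_size)
    print(batch_size, sum_loss, mean_loss)
# 16 2.0 0.125
# 64 8.0 0.125
# 128 16.0 0.125
```

The mean() column reads the same at every batch size, while the sum() column has to be divided by the batch size before it means anything.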
In summary, because mean() reduced losses remain on the scale of sample-wise losses, they are much easier to reason about than sum() reduced losses. One point for mean().
Learning Rate Tuning
Unfortunately, there is a con to using mean(). The sum() method allows you to tune learning rates agnostic to batch sizes, while mean() is not batch-size agnostic.
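One way to see where the difference comes from is to write out a single plain SGD step under each reduction. The symbols here are introduced for illustration only: η is the learning rate, B is the batch size, and ℓ_i is the loss of sample i.

```latex
\begin{aligned}
\Delta W_{\text{sum}}  &= -\eta \, \nabla_W \sum_{i=1}^{B} \ell_i \\
\Delta W_{\text{mean}} &= -\eta \, \nabla_W \frac{1}{B} \sum_{i=1}^{B} \ell_i
                        = \frac{1}{B} \, \Delta W_{\text{sum}}
\end{aligned}
```

Under mean(), each sample's contribution to the step shrinks as B grows, unless η is scaled up by the same factor.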
To showcase this concept, let's walk through the process of computing the update to a weight using the sum() reduction technique. Let's take a learning rate of 0.01, and say that our network is currently misclassifying 12.5% of examples, or 8 misclassified examples per batch of 64. First, consider two batches of 64 examples, each yielding a loss of 8. In this example, assume the gradient with respect to the weight W_1 is 0.5 for each example where the sample is misclassified and 0 otherwise. So, our total loss for a batch is 8, as there are 8 misclassified examples out of the 64. Following this, the weight update ΔW_1 will be equal to -8*0.5*0.01=-0.04 for each batch, so your final weight update after the two batches will be ΔW_1=-0.08.
Now imagine doubling the batch size, so instead we are computing the loss for all 128 examples at once. Well, we simply sum() the loss for each sample, so the loss is 16 for the mega batch, and ΔW_1=-16*0.5*0.01=-0.08. Great, the result is unchanged: it matches the combined update of -0.08 from the two batches of 64!
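The sum() walkthrough can be sketched in a few lines of Python. The function name `sum_update` is made up for this example, and the sketch assumes, as the text does, that every misclassified sample contributes a gradient of 0.5 and that the gradient is the same for both small batches.

```python
LEARNING_RATE = 0.01
GRAD_PER_MISCLASSIFIED = 0.5  # gradient w.r.t. W_1 per misclassified sample


def sum_update(num_misclassified, learning_rate=LEARNING_RATE):
    """Weight update for one batch under sum() reduction."""
    summed_gradient = num_misclassified * GRAD_PER_MISCLASSIFIED
    return -learning_rate * summed_gradient


# Two batches of 64 with 8 misclassified samples each.
two_small_batches = sum_update(8) + sum_update(8)

# One batch of 128 with 16 misclassified samples.
one_large_batch = sum_update(16)

print(two_small_batches, one_large_batch)  # both about -0.08
```

With sum() the total movement of W_1 over the same 128 samples does not depend on how they are split into batches.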
Now, let's walk through the same example, but using the mean() reduction technique. To maintain numerical equivalence, we use a learning rate of 0.01*64. This may seem strange, but hold tight - this will be clearer in a moment! Again, the gradient is 0.5 for each example where the sample is misclassified and 0 otherwise. However, this time the loss for the batch is (8*1 + 56*0) / 64, or 0.125. Next, we compute -0.125*0.5*0.01*64, again yielding ΔW_1=-0.04 for each batch of 64. So far so good...
Now let's double the batch size again, to 128. Even if you know you need to scale the learning rate by the batch size because you're using mean(), this is an easy change to forget to make! I've seen it cause issues in more code bases than I can count.
Due to the mean() reduction, our loss for the mega batch is (2*8*1 + 2*56*0) / (2*64), or again 0.125. With the learning rate left at 0.01*64, our weight update is therefore identical to that of a single batch of 64, again leading to ΔW_1=-0.04.
So what is the issue with this? Well, you have taken a single step of -0.04 for 128 examples, where two batches of 64 would have produced -0.08 in total: you have effectively halved your learning rate. What a waste of resources! You'll need roughly twice as many epochs to reach the same point in training. Even worse, if you halve your batch size you will effectively be doubling your learning rate! A doubled learning rate will very likely lead to significantly worse final performance. While this is an issue you can work around by scaling your learning rate to your batch size, it is a huge pain.
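Here is the same comparison under mean() as a short Python sketch. The name `mean_update` is hypothetical, and the learning rate is deliberately left at the value tuned for batches of 64 to reproduce the mistake described above.

```python
BASE_LEARNING_RATE = 0.01 * 64  # tuned for batches of 64, as in the text
GRAD_PER_MISCLASSIFIED = 0.5    # gradient w.r.t. W_1 per misclassified sample


def mean_update(num_misclassified, batch_size, learning_rate):
    """Weight update for one batch under mean() reduction."""
    mean_gradient = num_misclassified * GRAD_PER_MISCLASSIFIED / batch_size
    return -learning_rate * mean_gradient


# Two batches of 64 with 8 misclassified samples each.
two_small_batches = 2 * mean_update(8, 64, BASE_LEARNING_RATE)

# One batch of 128 with 16 misclassified samples, learning rate unchanged.
one_large_batch = mean_update(16, 128, BASE_LEARNING_RATE)

print(two_small_batches)  # about -0.08
print(one_large_batch)    # about -0.04
```

The large batch moves W_1 half as far over the same 128 samples, which is the halved effective learning rate.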
As you can imagine, this can be very annoying when tuning BOTH the learning rate and batch size of your training script! One point for sum().
The Perfect Solution: My Recommendation
My recommendation is to use the mean() strategy and a learning rate inferred from the BATCH_SIZE of your training script. This gives you the best of both worlds: your reported loss will remain sample-wise, and your script will remain batch-size agnostic.
You can achieve this with the following code:
learning_rate = BATCH_SIZE * some_constant_learning_rate
It's usually a bit easier to follow this learning rate inference if you write the constant in terms of the batch size at which you originally tuned your learning rate, for example:
learning_rate = BATCH_SIZE * 0.01 / 16
In this example it is clear that the learning rate will be 0.01 when BATCH_SIZE=16. This pattern effectively allows you to tune the learning rate with respect to samples, while maintaining per-sample loss reporting!
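As a rough check of this pattern, the sketch below scales the learning rate with the batch size and compares the total weight movement over 128 samples. The names `scaled_learning_rate` and `total_update` are made up for illustration, and the 12.5% error rate and 0.5 gradient are carried over from the earlier walkthrough as assumptions.

```python
TUNED_LEARNING_RATE = 0.01  # learning rate found during tuning
TUNED_BATCH_SIZE = 16       # batch size used during that tuning


def scaled_learning_rate(batch_size):
    """Learning rate inferred from the batch size (linear scaling)."""
    return batch_size * TUNED_LEARNING_RATE / TUNED_BATCH_SIZE


def total_update(batch_size, num_samples=128, error_rate=0.125, grad=0.5):
    """Total weight movement over num_samples under mean() reduction."""
    learning_rate = scaled_learning_rate(batch_size)
    mean_gradient = error_rate * grad  # same for every batch in this toy setup
    num_batches = num_samples // batch_size
    return num_batches * (-learning_rate * mean_gradient)


for batch_size in (16, 64, 128):
    print(batch_size, scaled_learning_rate(batch_size), total_update(batch_size))
```

In this toy setup the total update comes out the same for every batch size, while the reported mean() loss stays at 0.125 throughout.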
Congratulations! You now know everything you need to know about the tradeoffs of reduction=SUM and reduction=MEAN, as well as the perfect solution:
import tensorflow as tf

learning_rate = BATCH_SIZE * your_tuned_lr / your_original_batch_size
loss_obj = tf.keras.losses.CategoricalCrossentropy(
    # in Keras mean reduction is called `SUM_OVER_BATCH_SIZE`
    reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE
)