
The Perfect Solution to Loss Reduction in Deep Learning!



In deep learning, the most common optimization algorithm used in training is mini-batch stochastic gradient descent. In this algorithm, a random batch of data points is sampled from your training set, and weight gradients are computed with respect to the entire batch.

Consider what happens when you compute the loss for a batch of samples. The example we will use is classification, where your mini-batch consists of 64 images and their 64 corresponding labels. When computing the loss, you initially get a batch of losses: a vector of dimension 64 in this example.

When computing the gradient of each weight in your network with respect to this loss, you need a scalar loss, not the 64-dimensional loss vector described above. But what strategy do you employ to convert the 64-dimensional vector to a scalar? There are two obvious candidates: sum() and mean(). Both can yield identical numerical results, and they may seem interchangeable on the surface, but there is an important tradeoff to consider when picking your loss reduction strategy.
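Here is a minimal sketch of the two reductions in plain Python. The helper name reduce_losses and the four loss values are made up for illustration; a real batch would hold 64 values produced by your loss function.

```python
def reduce_losses(per_sample_losses, reduction):
    """Collapse a vector of per-sample losses into a single scalar."""
    total = sum(per_sample_losses)
    if reduction == "sum":
        return total
    if reduction == "mean":
        return total / len(per_sample_losses)
    raise ValueError(f"unknown reduction: {reduction!r}")


# Hypothetical per-sample losses for a tiny batch of 4 examples.
per_sample_losses = [0.25, 1.5, 0.0, 0.75]

print(reduce_losses(per_sample_losses, "sum"))   # 2.5
print(reduce_losses(per_sample_losses, "mean"))  # 0.625
```

The two scalars differ only by a factor of the batch size, which is exactly where the tradeoff discussed below comes from.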

Batch-Level Loss Reporting

Let's start by discussing loss reporting.

The main benefit of using the mean() reduction strategy is that your final reported losses for each batch are scaled to resemble a sample-wise loss. This makes them extremely easy to reason about.

Let's walk through an example using the 0-1 loss used in traditional machine learning. Imagine you see a final loss of 0.125. With this reported number it is extremely easy to figure out that you have misclassified 12.5% of the examples in your batch. Great!

Compare this to the output of the sum() reduction: 8. What is a loss of 8? Well, it means there are 8 misclassified samples in the batch. Is that good? Is that bad? Well, we have 64 samples in the batch, so 8/64 samples were misclassified. You can reach the same result, but imagine you are trying to tune your BATCH_SIZE and reason about the results for each one. This quickly becomes extremely annoying.

In summary, because mean()-reduced losses remain on the scale of sample-wise losses, they are much easier to reason about than sum()-reduced losses. One point for mean()!
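To make the reporting difference concrete, here is a small sketch that builds batches of 0-1 losses at a fixed 12.5% error rate and reports both reductions for a few batch sizes. The helper names zero_one_losses and report are hypothetical, and the batches are synthetic.

```python
def zero_one_losses(batch_size, error_rate):
    """Build a synthetic batch of 0-1 losses with the given error rate."""
    n_wrong = round(batch_size * error_rate)
    return [1.0] * n_wrong + [0.0] * (batch_size - n_wrong)


def report(batch_size, error_rate=0.125):
    """Return (sum-reduced loss, mean-reduced loss) for one batch."""
    losses = zero_one_losses(batch_size, error_rate)
    total = sum(losses)
    return total, total / batch_size


# The sum grows with the batch size (4.0, 8.0, 16.0),
# while the mean stays at 0.125 throughout.
for batch_size in (32, 64, 128):
    sum_loss, mean_loss = report(batch_size)
    print(f"BATCH_SIZE={batch_size}  sum={sum_loss}  mean={mean_loss}")
```

Only the mean() column can be compared across runs without first dividing by BATCH_SIZE in your head.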

Learning Rate Tuning

Unfortunately, there is a downside to using mean(). The sum() reduction lets you tune the learning rate independently of the batch size, while mean() does not.

To showcase this concept, let's walk through the process of computing the update to a weight using the sum() reduction. Let's take a learning rate of 0.01, and say that our network is currently misclassifying 12.5% of examples, or 8 misclassified examples per batch of 64. First, consider two batches of 64 examples, each yielding a loss of 8.

In this example, dL/dW_1 is 0.5 for each misclassified example and 0 otherwise. Our total loss for a batch is 8, as there are 8 misclassified examples out of the 64. The weight update ΔW_1 for W_1 is therefore -8*0.5*0.01=-0.04 for each batch, so the total weight update after the two batches is ΔW_1=-2*8*0.5*0.01=-0.08.

Now imagine doubling the batch size, so instead we compute the loss for all 128 examples at once. We simply sum() the loss for each sample, so the loss is 16 for the mega batch, and ΔW_1=-16*0.5*0.01=-0.08. Great, the result is unchanged!
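The sum() walkthrough can be checked with a few lines of Python. This is only a sketch of the toy setup above: the function name sum_reduced_update is made up, and it assumes the per-example gradient stays at 0.5 between the two small-batch steps.

```python
import math

LEARNING_RATE = 0.01
GRAD_PER_MISCLASSIFIED = 0.5  # dL/dW_1 for a misclassified example, 0 otherwise


def sum_reduced_update(n_misclassified, learning_rate=LEARNING_RATE):
    """Weight update for one batch under sum() reduction."""
    batch_gradient = n_misclassified * GRAD_PER_MISCLASSIFIED
    return -learning_rate * batch_gradient


# Two batches of 64 with 8 misclassified examples each, applied back to back.
# Simplifying assumption: the gradient does not change after the first step.
two_small_batches = sum_reduced_update(8) + sum_reduced_update(8)

# One batch of 128 with 16 misclassified examples.
one_large_batch = sum_reduced_update(16)

print(two_small_batches, one_large_batch)
assert math.isclose(two_small_batches, one_large_batch)
```

Both schedules move W_1 by the same total amount, with no change to the learning rate.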

Now, let's walk through the same example, but using the mean() strategy.

To maintain numerical equivalency, we use a learning rate of 0.01*64. This may seem strange, but hold tight - this will be clearer in a moment!

Again, dL/dW_1 is 0.5 for each misclassified example and 0 otherwise. This time, however, the loss for the batch is (8*1 + 56*0) / 64, or 0.125. Next, we compute ΔW_1 for one batch: -0.125*0.5*0.01*64=-0.04, the same per-batch update that sum() gave us. Two batches therefore move W_1 by -0.08 in total. So far so good...

Now let's double the batch size, again to 128, while leaving the learning rate untouched. Even if you know you need to scale the learning rate by the batch size when using mean(), this is an easy change to forget to make! I've seen it cause issues in more code bases than I can count.

Due to the mean() reduction, our loss for the mega batch is (2*8*1 + 2*56*0) / (2*64), or again 0.125. The weight update for the mega batch is therefore identical to the update for a single batch of 64, ΔW_1=-0.04, even though it now covers 128 examples instead of 64.

So what is the issue with this? Well, you have effectively halved your learning rate. What a huge waste of resources! You'll need to double your epoch count to achieve the same convergence. Even worse, if you halve your batch size you will be doubling your effective learning rate! A doubled learning rate will very likely lead to significantly worse final performance. While this is an issue you can work around by scaling your learning rate to your batch size, it is a huge pain.

As you can imagine, this can be very annoying when tuning BOTH the learning rate and batch size of your training script! One point for sum()!
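The halving described in this section can be reproduced with the same toy numbers. As before, this is a sketch under the assumptions of the walkthrough: mean_reduced_update is a made-up helper, and the learning rate is deliberately left at the value tuned for batches of 64.

```python
import math

GRAD_PER_MISCLASSIFIED = 0.5  # dL/dW_1 for a misclassified example, 0 otherwise
LEARNING_RATE = 0.01 * 64     # tuned for batches of 64, then left unchanged


def mean_reduced_update(n_misclassified, batch_size, learning_rate=LEARNING_RATE):
    """Weight update for one batch under mean() reduction."""
    batch_gradient = n_misclassified * GRAD_PER_MISCLASSIFIED / batch_size
    return -learning_rate * batch_gradient


# 128 examples seen as two batches of 64 (gradient assumed fixed between steps).
two_small_batches = 2 * mean_reduced_update(8, 64)

# The same 128 examples seen as one batch, with the learning rate untouched.
one_large_batch = mean_reduced_update(16, 128)

print(two_small_batches, one_large_batch)
# The large batch moves W_1 only half as far over the same data.
assert math.isclose(one_large_batch, two_small_batches / 2)
```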

The Perfect Solution: My Recommendation

My recommendation is to use the mean() strategy and a learning rate inferred from the BATCH_SIZE of your training script. This allows you to get the best of both worlds: your reported loss will remain sample-wise, and your script will remain batch-size agnostic.

You can achieve this with the following code:

learning_rate = BATCH_SIZE * some_constant_learning_rate

It's usually a bit easier to follow this learning rate inference if you write your constant learning rate in terms of the original batch size you tuned it with, for example:

learning_rate = BATCH_SIZE * 0.01 / 16

In this example it is clear that the learning rate will be 0.01 when BATCH_SIZE=16. This pattern effectively allows you to tune the learning rate with respect to samples, while maintaining per-sample loss reporting!
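As a rough check that this pattern behaves as intended, the sketch below combines the scaled learning rate with the mean() reduction on the toy numbers from the earlier walkthrough. The names TUNED_LR, TUNED_BATCH_SIZE, scaled_learning_rate and mean_reduced_update are illustrative, and the 0.5 gradient per misclassified example is the same assumption as before.

```python
import math

TUNED_LR = 0.01        # learning rate found during tuning
TUNED_BATCH_SIZE = 16  # batch size that the tuning was done with


def scaled_learning_rate(batch_size):
    """Scale the tuned learning rate linearly with the batch size."""
    return batch_size * TUNED_LR / TUNED_BATCH_SIZE


def mean_reduced_update(n_misclassified, batch_size, grad_per_error=0.5):
    """Update for one batch under mean() reduction with the scaled rate."""
    batch_gradient = n_misclassified * grad_per_error / batch_size
    return -scaled_learning_rate(batch_size) * batch_gradient


# The learning rate per sample stays constant across batch sizes.
for batch_size in (16, 32, 64, 128):
    print(batch_size, scaled_learning_rate(batch_size))

# 128 examples as two batches of 64 versus one batch of 128.
two_small_batches = 2 * mean_reduced_update(8, 64)
one_large_batch = mean_reduced_update(16, 128)
assert math.isclose(two_small_batches, one_large_batch)
```

With the learning rate tied to BATCH_SIZE, the total movement over the same 128 examples matches again, and the reported loss is still the per-sample 0.125.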

Congratulations! You now know everything you need to know about the tradeoffs of reduction=SUM vs reduction=MEAN, as well as the perfect solution:

learning_rate = BATCH_SIZE * your_tuned_lr / your_original_batch_size
loss_obj = tf.keras.losses.CategoricalCrossentropy(
  # in Keras mean reduction is called `SUM_OVER_BATCH_SIZE`