# The Perfect Solution to Loss Reduction in Deep Learning!

- Author: Luke Wood (@luke_wood_ml)

In deep learning, the most common optimization algorithm used in training is mini-batch stochastic gradient descent. In this algorithm, a random sample of data points is drawn from your training set, and weight gradients are computed with respect to the entire batch.

Consider what happens when you compute the loss for a batch of samples. As a running example, take a classification task where your mini-batch consists of `64` images and their `64` corresponding labels. When computing the loss, you initially get a batch of losses: a vector of dimension `64` in our example.

When computing the gradient of each weight in your network with respect to this loss, you consider a scalar loss; not the `64`-dimensional loss vector described above. But what strategy do you employ to convert the `64`-dimensional vector to a scalar? There are two obvious candidates: `sum()` and `mean()`. Both are capable of yielding identical numerical results and may seem interchangeable on the surface, but there is actually an important tradeoff to consider when picking your loss reduction strategy.
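To make the two strategies concrete, here is a minimal sketch (the loss values are hypothetical, chosen to match the running example of `8` misclassified samples out of `64`):

```python
import numpy as np

# A hypothetical batch of 64 per-sample 0-1 losses: 8 misclassified samples.
per_sample_losses = np.zeros(64)
per_sample_losses[:8] = 1.0

# The two candidate reduction strategies:
loss_sum = per_sample_losses.sum()    # -> 8.0
loss_mean = per_sample_losses.mean()  # -> 0.125

print(loss_sum, loss_mean)
```

Both reductions start from the same per-sample loss vector; the only difference is whether you divide by the batch size at the end.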

## Batch-level loss reporting

Let's start by discussing loss reporting.

The main benefit of using the `mean()` reduction strategy is that your final reported losses for each batch are scaled to resemble a sample-wise loss. This makes them extremely easy to reason about.

Let's walk through an example using the `0-1` loss from traditional machine learning. Imagine you see a final loss of `0.125`. From this reported number it is extremely easy to see that you have misclassified `12.5%` of the examples in your batch. Great!

Compare this to the output of the `sum()` reduction: `8`. What is a loss of `8`? Well, it means there are `8` misclassified samples in the batch. Is that good? Is that bad? We have `64` samples in the batch, so `8/64` samples were misclassified. You can reach the same conclusion, but imagine trying to tune your `BATCH_SIZE` and reason about the results for each candidate value. This quickly becomes extremely annoying.
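A short sketch of why `sum()` makes batch-size tuning annoying, assuming a fixed `12.5%` error rate under the `0-1` loss:

```python
import numpy as np

results = {}
for batch_size in (16, 64, 256):
    # 0-1 per-sample losses with a fixed 12.5% error rate (assumed for illustration)
    losses = np.zeros(batch_size)
    losses[: batch_size // 8] = 1.0
    results[batch_size] = (losses.mean(), losses.sum())

print(results)  # mean is 0.125 for every batch size; sum is 2.0, 8.0, 32.0
```

The `mean()` column is directly comparable across batch sizes; the `sum()` column has to be re-interpreted every time `BATCH_SIZE` changes.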

In summary, because `mean()`-reduced losses remain on the scale of sample-wise losses, they are much easier to reason about than `sum()`-reduced losses. One point for `mean()`!

## Learning Rate Tuning

Unfortunately, there is a downside to using `mean()`. The `sum()` reduction allows you to tune learning rates agnostic of batch size, while `mean()` does not.

To showcase this concept, let's walk through the process of computing a weight update using the `sum()` reduction. Take a learning rate of `0.01`, and say that our network is currently misclassifying `12.5%` of examples, or `8` misclassified examples per batch. First, consider two batches of `64` examples, each yielding a loss of `8`.

In this example, `dL/dW_1` is `0.5` for each misclassified sample and `0` otherwise. So our total loss for a batch is `8`, as there are `8` misclassified examples out of the `64`. Following this, the weight update `ΔW_1` for `W_1` will be equal to `-8*0.5*0.01=-0.04` for each batch, so your total weight update after the two batches will be `ΔW_1=-2*8*0.5*0.01=-0.08`.

Now imagine doubling the batch size, so instead we compute the loss over all `128` examples at once. We simply `sum()` the loss for each sample, so the loss is `16` for the mega batch, and `ΔW_1=-16*0.5*0.01=-0.08`. Great, the result is unchanged!
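The arithmetic above can be sketched in a few lines, assuming the same hypothetical per-sample gradient of `0.5` for each misclassified example:

```python
lr = 0.01
grad_per_error = 0.5  # assumed dL/dW_1 for each misclassified sample

# Two batches of 64, 8 errors each, sum() reduction:
update_two_batches = sum(-lr * (8 * grad_per_error) for _ in range(2))  # -0.08

# One mega batch of 128 with 16 errors, sum() reduction:
update_mega_batch = -lr * (16 * grad_per_error)  # -0.08

print(update_two_batches, update_mega_batch)  # identical: sum() is batch-size agnostic
```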

Now, let's walk through the same example, but using the `mean()` strategy.

To maintain numerical equivalency, we use a learning rate of `0.01*64`. This may seem strange, but hold tight - this will be clearer in a moment!

Again, `dL/dW_1` is `0.5` for each misclassified sample and `0` otherwise. This time, however, the loss for the batch is `(8*1 + 56*0) / 64`, or `0.125`. Next, we compute `ΔW_1 = -0.125*0.5*0.01*64`, again yielding `ΔW_1=-0.04` per batch, or `-0.08` over the two batches. So far so good...

Now let's double the batch size, again to `128`. Since we're using `mean()`, we should also scale the learning rate by the batch size - but even if you know this, it is an easy change to forget to make! I've seen it cause issues in more code bases than I can count. Suppose we forget, and leave the learning rate at `0.01*64`.

Due to the `mean()` reduction, our loss for the mega batch is `(2*8*1 + 2*56*0) / (2*64)`, or again `0.125`. Our weight update is therefore `ΔW_1=-0.04` - the same as for a single small batch, even though we have now processed twice as much data.

So what is the issue with this? Well, you have effectively halved your learning rate: the two small batches produced a total update of `-0.08`, while the mega batch produced only `-0.04`. What a huge waste of resources! You'll need to double your epoch count to achieve the same convergence. Even worse, if you instead reduce your batch size, you will effectively double your learning rate - which will very likely yield significantly worse final performance. While you can work around this by scaling your learning rate with your batch size, it is a huge pain.
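Here is the same walkthrough with `mean()` reduction, again assuming the hypothetical per-sample gradient of `0.5`:

```python
lr = 0.01 * 64  # learning rate tuned for BATCH_SIZE=64, as in the text
grad_per_error = 0.5  # assumed dL/dW_1 for each misclassified sample

# Two batches of 64 (8 errors each), mean() reduction:
per_batch_loss = 8 / 64  # 0.125
update_two_batches = 2 * (-lr * per_batch_loss * grad_per_error)  # -0.08

# One mega batch of 128 (16 errors), mean() reduction, lr left unchanged:
mega_loss = 16 / 128  # still 0.125
update_mega_batch = -lr * mega_loss * grad_per_error  # -0.04: effective lr halved

print(update_two_batches, update_mega_batch)
```

The mega batch covers the same data but produces half the total update, which is exactly the "silently halved learning rate" failure mode described above.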

As you can imagine, this can be very annoying when tuning BOTH the learning rate and the batch size of your training script! One point for `sum()`!

## The Perfect Solution: My Recommendation

My recommendation is to use the `mean()` strategy together with a learning rate inferred from the `BATCH_SIZE` of your training script. This gives you the best of both worlds: your reported loss remains sample-wise, and your script remains batch-size agnostic.

You can achieve this with the following code:

```
learning_rate = BATCH_SIZE * some_constant_learning_rate
```

It's usually a bit easier to follow this learning rate inference process if you write your constant learning rate as an expression of the original batch size you tuned your learning rate with respect to, for example:

```
learning_rate = BATCH_SIZE * 0.01 / 16
```

In this example it is clear that the learning rate will be `0.01` when `BATCH_SIZE=16`. This pattern effectively allows you to tune the learning rate with respect to samples, while maintaining per-sample loss reporting!

Congratulations! You now know everything you need to know about the tradeoffs of `reduction=SUM` vs `reduction=MEAN`, as well as the perfect solution:

```
learning_rate = BATCH_SIZE * your_tuned_lr / your_original_batch_size
loss_obj = tf.keras.losses.CategoricalCrossentropy(
    # in Keras, mean reduction is called `SUM_OVER_BATCH_SIZE`
    reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE
)
```
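As a sanity check on the recommendation, here is a minimal sketch showing that `mean()` reduction plus a batch-size-scaled learning rate is batch-size agnostic. The per-sample gradients are fixed and hypothetical, so one pass over the data should produce the same total weight update regardless of batch size:

```python
import numpy as np

sample_grads = np.linspace(0.0, 1.0, 128)  # hypothetical per-sample dL/dW
base_lr = 0.01                             # lr tuned at batch size 1, for illustration

def total_update(batch_size):
    lr = batch_size * base_lr  # the recommended batch-size-scaled learning rate
    update = 0.0
    for i in range(0, len(sample_grads), batch_size):
        batch = sample_grads[i : i + batch_size]
        update += -lr * batch.mean()  # mean() reduction
    return update

print(total_update(16), total_update(64))  # identical up to float error
```

Scaling the learning rate by the batch size cancels the division inside `mean()`, so each batch contributes its per-sample gradients at full strength - the same behavior `sum()` gives you, with `mean()`-style loss reporting.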