A Common Pitfall When Writing Subclass Models in Keras

September 23, 2022

Overview

The other day, while benchmarking the KerasCV RetinaNet, I was shocked to see that neither the prediction head nor the regression head was training. The model achieved mediocre scores, yet the weights were never being updated. After almost a week of debugging, I found the culprit.

TL;DR: subclass models cannot track the weights used in closure-style components, which are a commonly used abstraction in the keras_cv.models subpackage.

Demonstration

Let's demonstrate this edge case:
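A minimal sketch of the kind of model in question; the layer names and sizes here are illustrative, not the exact code from the benchmark:

```python
import tensorflow as tf

class SubclassModel(tf.keras.Model):
    """A subclass model whose sub-layers are assigned as attributes."""

    def __init__(self):
        super().__init__()
        # Layers assigned directly to attributes are tracked by Keras,
        # so their weights appear in `model.trainable_variables`.
        self.dense1 = tf.keras.layers.Dense(16, activation="relu")
        self.dense2 = tf.keras.layers.Dense(1)

    def call(self, inputs):
        return self.dense2(self.dense1(inputs))
```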

This is a properly written subclass model. You can fit it using:
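For example, something along these lines (the illustrative model definition is repeated so the snippet runs on its own; the toy data is arbitrary):

```python
import numpy as np
import tensorflow as tf

class SubclassModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.dense1 = tf.keras.layers.Dense(16, activation="relu")
        self.dense2 = tf.keras.layers.Dense(1)

    def call(self, inputs):
        return self.dense2(self.dense1(inputs))

model = SubclassModel()
model.compile(optimizer="adam", loss="mse")

# Toy data; any (x, y) pair of compatible shapes will do.
x = np.random.random((32, 8)).astype("float32")
y = np.random.random((32, 1)).astype("float32")
history = model.fit(x, y, epochs=1, verbose=0)
```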

Cool, it works. That's all there is to it. Let's rewrite this using the closure pattern, abstracting the dense layers into a single block.
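A closure-style block looks roughly like this (the name DenseBlock and the sizes are illustrative):

```python
import tensorflow as tf

def DenseBlock(units=16):
    # These layers live only in this function's closure; they are not
    # attributes of any Layer or Model.
    dense1 = tf.keras.layers.Dense(units, activation="relu")
    dense2 = tf.keras.layers.Dense(units, activation="relu")

    def apply(x):
        return dense2(dense1(x))

    return apply
```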

Great, and let's use it in our model:
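A self-contained sketch, repeating the closure so the snippet runs standalone:

```python
import numpy as np
import tensorflow as tf

def DenseBlock(units=16):
    dense1 = tf.keras.layers.Dense(units, activation="relu")
    dense2 = tf.keras.layers.Dense(units, activation="relu")
    def apply(x):
        return dense2(dense1(x))
    return apply

class ClosureModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        # `self.block` is a plain Python function, not a Layer.
        self.block = DenseBlock()
        self.out = tf.keras.layers.Dense(1)

    def call(self, inputs):
        return self.out(self.block(inputs))

model = ClosureModel()
model.compile(optimizer="adam", loss="mse")
x = np.random.random((32, 8)).astype("float32")
y = np.random.random((32, 1)).astype("float32")
# This runs without error, but the optimizer only ever updates the
# variables Keras can see -- and the block's weights are not among them.
model.fit(x, y, epochs=5, verbose=0)
```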

Notice something? The model doesn't learn.

Why? Let's dig deeper:
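Inspecting the model's tracked variables makes the problem visible (again repeating the illustrative definitions so this runs standalone):

```python
import tensorflow as tf

def DenseBlock(units=16):
    dense1 = tf.keras.layers.Dense(units, activation="relu")
    dense2 = tf.keras.layers.Dense(units, activation="relu")
    def apply(x):
        return dense2(dense1(x))
    return apply

class ClosureModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.block = DenseBlock()
        self.out = tf.keras.layers.Dense(1)

    def call(self, inputs):
        return self.out(self.block(inputs))

model = ClosureModel()
model(tf.zeros((1, 8)))  # build the model

# Only the output layer's kernel and bias are tracked; the four weights
# created inside DenseBlock's closure are invisible to Keras.
print(len(model.trainable_variables))  # prints 2, not 6
```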

There lies the culprit: the model doesn't know about the variables created inside DenseBlock(). This is because Keras cannot inspect the internals of the assigned function and recognize that it represents a layer.

To fix this, just construct a custom layer, a functional model, or a Sequential model representing the same structure:
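A custom-layer version might look like this (same illustrative structure as the closure above):

```python
import tensorflow as tf

class DenseBlock(tf.keras.layers.Layer):
    """Same structure as the closure version, but as a real Layer,
    so Keras tracks its sub-layers and their weights."""

    def __init__(self, units=16, **kwargs):
        super().__init__(**kwargs)
        self.dense1 = tf.keras.layers.Dense(units, activation="relu")
        self.dense2 = tf.keras.layers.Dense(units, activation="relu")

    def call(self, x):
        return self.dense2(self.dense1(x))
```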

Or:
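Equivalently, as a Sequential model (a Sequential is itself a Layer, so Keras can track it when it is assigned as an attribute):

```python
import tensorflow as tf

def DenseBlock(units=16):
    # Returning a Sequential keeps the convenient factory-function style
    # while still giving Keras a trackable object.
    return tf.keras.Sequential([
        tf.keras.layers.Dense(units, activation="relu"),
        tf.keras.layers.Dense(units, activation="relu"),
    ])
```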

That's all; now our custom model will work again:
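Putting it together with the Layer-based block (self-contained sketch; names and sizes are illustrative):

```python
import numpy as np
import tensorflow as tf

class DenseBlock(tf.keras.layers.Layer):
    def __init__(self, units=16, **kwargs):
        super().__init__(**kwargs)
        self.dense1 = tf.keras.layers.Dense(units, activation="relu")
        self.dense2 = tf.keras.layers.Dense(units, activation="relu")

    def call(self, x):
        return self.dense2(self.dense1(x))

class FixedModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.block = DenseBlock()   # a real Layer, so it is tracked
        self.out = tf.keras.layers.Dense(1)

    def call(self, inputs):
        return self.out(self.block(inputs))

model = FixedModel()
model(tf.zeros((1, 8)))
# All six weights (a kernel and bias per Dense layer) are now visible
# to the optimizer.
assert len(model.trainable_variables) == 6

model.compile(optimizer="adam", loss="mse")
x = np.random.random((32, 8)).astype("float32")
y = np.random.random((32, 1)).astype("float32")
model.fit(x, y, epochs=1, verbose=0)
```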

Conclusion

The takeaway is that subclass models can't track variables used in closure-style functional layers.

Unless you really know Keras inside and out, just use tf.keras.layers.Layer when building a subclass model. Keep closure-style functions to the functional model-building API.