A Common Pitfall When Writing Subclass Models in Keras
Overview
The other day, while benchmarking the KerasCV RetinaNet, I was shocked to see that the prediction and regression heads were both failing to train. The model achieved mediocre scores, but their weights were not being updated. After almost a week of debugging I found the culprit.
TLDR: subclass models cannot track the weights used in closure-style components, which are a commonly used abstraction in the keras_cv.models subpackage.
Demonstration
Let's demonstrate this edge case:
This is a properly written subclass model. You can fit it using:
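Something like the following, with illustrative names, layer sizes, and dummy data (not the original post's exact code):

```python
import numpy as np
import tensorflow as tf

# A properly written subclass model: layers are assigned as attributes in
# __init__, so Keras tracks their variables automatically.
class SubclassModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.dense1 = tf.keras.layers.Dense(32, activation="relu")
        self.dense2 = tf.keras.layers.Dense(1)

    def call(self, inputs):
        return self.dense2(self.dense1(inputs))

model = SubclassModel()
model.compile(optimizer="adam", loss="mse")

# Dummy data, just to show fit() running.
x = np.random.uniform(size=(64, 8)).astype("float32")
y = np.random.uniform(size=(64, 1)).astype("float32")
model.fit(x, y, epochs=1, verbose=0)
print(len(model.trainable_variables))  # 4: a kernel and bias per Dense layer
```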
Cool, it works. That's all there is to it. Let's rewrite this using the closure pattern to abstract the dense layers into a single block.
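A minimal sketch of the closure pattern (the name DenseBlock and the layer sizes are illustrative):

```python
import tensorflow as tf

# Closure-style block: build the layers once, then return a function that
# applies them. This mirrors the abstraction used in keras_cv.models.
def DenseBlock(units):
    dense1 = tf.keras.layers.Dense(units, activation="relu")
    dense2 = tf.keras.layers.Dense(units)

    def apply(x):
        return dense2(dense1(x))

    return apply
```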
Great, and let's use it in our model:
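A sketch of the model using the closure; the DenseBlock closure is repeated so the snippet stands alone, and all names are illustrative:

```python
import numpy as np
import tensorflow as tf

def DenseBlock(units):
    dense1 = tf.keras.layers.Dense(units, activation="relu")
    dense2 = tf.keras.layers.Dense(units)

    def apply(x):
        return dense2(dense1(x))

    return apply

class BrokenModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        # The closure is a plain function attribute, so Keras never sees
        # the two Dense layers captured inside it.
        self.block = DenseBlock(16)
        self.out = tf.keras.layers.Dense(1)

    def call(self, inputs):
        return self.out(self.block(inputs))

model = BrokenModel()
model.compile(optimizer="adam", loss="mse")
x = np.random.uniform(size=(64, 8)).astype("float32")
y = np.random.uniform(size=(64, 1)).astype("float32")
model.fit(x, y, epochs=1, verbose=0)
```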
Notice something? The model doesn't learn.
Why? Let's dig deeper:
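One way to see the problem is to inspect the model's trainable variables. Assuming the BrokenModel sketch from above (repeated here so the snippet runs standalone):

```python
import tensorflow as tf

def DenseBlock(units):
    dense1 = tf.keras.layers.Dense(units, activation="relu")
    dense2 = tf.keras.layers.Dense(units)

    def apply(x):
        return dense2(dense1(x))

    return apply

class BrokenModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.block = DenseBlock(16)
        self.out = tf.keras.layers.Dense(1)

    def call(self, inputs):
        return self.out(self.block(inputs))

model = BrokenModel()
model(tf.zeros((1, 8)))  # call once so every layer creates its weights

# Three Dense layers' worth of weights exist, but only the output layer's
# kernel and bias are visible to the model (and thus to the optimizer).
for v in model.trainable_variables:
    print(v.name)
print(len(model.trainable_variables))  # 2, not 6
```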
Therein lies the culprit: the model doesn't know about the variables in DenseBlock().
This is because Keras cannot inspect the internals of the assigned function and recognize that it represents a layer.
To fix this, just construct a custom layer, functional model, or Sequential representing the same structure:
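For example, the block rewritten as a custom layer (a sketch; names are illustrative):

```python
import tensorflow as tf

# The same block as a tf.keras.layers.Layer subclass: layers assigned as
# attributes of a Layer are tracked, so their weights become visible.
class DenseBlock(tf.keras.layers.Layer):
    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.dense1 = tf.keras.layers.Dense(units, activation="relu")
        self.dense2 = tf.keras.layers.Dense(units)

    def call(self, x):
        return self.dense2(self.dense1(x))
```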
Or:
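An equivalent fix, sketched with a Sequential, which is itself a Layer and therefore trackable when assigned as an attribute:

```python
import tensorflow as tf

# Keeps the DenseBlock(units)(x) call style of the closure pattern, but
# returns a Sequential instead of a bare function.
def DenseBlock(units):
    return tf.keras.Sequential(
        [
            tf.keras.layers.Dense(units, activation="relu"),
            tf.keras.layers.Dense(units),
        ]
    )
```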
That's all; now our custom model will work again:
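Putting it together, using the Layer-based block (repeated so the snippet runs standalone; names and sizes are illustrative):

```python
import numpy as np
import tensorflow as tf

class DenseBlock(tf.keras.layers.Layer):
    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.dense1 = tf.keras.layers.Dense(units, activation="relu")
        self.dense2 = tf.keras.layers.Dense(units)

    def call(self, x):
        return self.dense2(self.dense1(x))

class FixedModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.block = DenseBlock(16)  # a Layer attribute, so it is tracked
        self.out = tf.keras.layers.Dense(1)

    def call(self, inputs):
        return self.out(self.block(inputs))

model = FixedModel()
model.compile(optimizer="adam", loss="mse")
x = np.random.uniform(size=(64, 8)).astype("float32")
y = np.random.uniform(size=(64, 1)).astype("float32")
model.fit(x, y, epochs=1, verbose=0)
print(len(model.trainable_variables))  # 6: all three Dense layers are tracked
```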
Conclusion
The takeaway is that subclass models can't track variables used in closure-style functional layers.
Unless you really know Keras inside and out, just use tf.keras.layers.Layer when building a subclass model. Keep closure-style functions to the functional model-building API.