A Common Pitfall When Writing Subclass Models in Keras
Author: Luke Wood (@luke_wood_ml)
Overview
The other day, while benchmarking the KerasCV RetinaNet, I was shocked to see that the prediction and regression heads were both failing to train. The model achieved mediocre scores, but the weights of those heads were never being updated. After almost a week of debugging, I found the culprit.
TL;DR: subclass models cannot track the weights created inside closure-style components, which are a commonly used abstraction in the keras_cv.models subpackage.
Demonstration
Let's demonstrate this edge case:
import tensorflow as tf
from tensorflow import keras
import numpy as np
class CustomModel(keras.Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.dense_1 = keras.layers.Dense(10)
        self.dense_2 = keras.layers.Dense(1)

    def call(self, x, training=None):
        x = self.dense_1(x)
        return self.dense_2(x)

    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            # Compute the loss value
            # (the loss function is configured in `compile()`)
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        return {m.name: m.result() for m in self.metrics}
This is a properly written subclass model. You can fit it using:
model = CustomModel()
model.compile(optimizer="adam", loss="mse")
# Just use `fit` as usual -- you can use callbacks, etc.
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=5)
Epoch 1/5
32/32 [==============================] - 0s 1ms/step - loss: 0.1832
Epoch 2/5
32/32 [==============================] - 0s 1ms/step - loss: 0.1476
Epoch 3/5
32/32 [==============================] - 0s 1ms/step - loss: 0.1306
Epoch 4/5
32/32 [==============================] - 0s 1ms/step - loss: 0.1150
Epoch 5/5
32/32 [==============================] - 0s 1ms/step - loss: 0.1073
<keras.callbacks.History at 0x12e7dc670>
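Before moving on, it is worth running a quick check that will matter later: the model's weights are visible through trainable_variables. A small sanity check on the model we just trained:

# Kernel + bias for each of the two Dense layers: 4 variables in total.
print(len(model.trainable_variables))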
Cool, it works. That's all there is to it. Now let's rewrite the same model using the closure pattern to abstract the two dense layers into a single block.
def DenseBlock():
    dense_1 = keras.layers.Dense(10)
    dense_2 = keras.layers.Dense(1)

    def apply(x, training=None):
        x = dense_1(x)
        return dense_2(x)

    return apply
Great, now let's use it in our model:
class CustomModel(keras.Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.dense_block = DenseBlock()

    def call(self, x, training=None):
        return self.dense_block(x)

    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            # Compute the loss value
            # (the loss function is configured in `compile()`)
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)  # Broken
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        return {m.name: m.result() for m in self.metrics}
model = CustomModel()
model.compile(optimizer="adam", loss="mse")
# Just use `fit` as usual -- you can use callbacks, etc.
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=5)
Epoch 1/5
32/32 [==============================] - 0s 974us/step - loss: 0.5985
Epoch 2/5
32/32 [==============================] - 0s 1ms/step - loss: 0.5985
Epoch 3/5
32/32 [==============================] - 0s 983us/step - loss: 0.5985
Epoch 4/5
32/32 [==============================] - 0s 1ms/step - loss: 0.5985
Epoch 5/5
32/32 [==============================] - 0s 1ms/step - loss: 0.5985
<keras.callbacks.History at 0x12eee7cd0>
Notice something? The model doesn't learn: the loss is identical every epoch.
Why? Let's dig deeper:
print(model.trainable_variables)
[]
There lies the culprit: the model doesn't know about the variables created inside DenseBlock(). Keras discovers weights by tracking the attributes you assign in __init__ (layers, models, and variables), but it cannot look inside a plain Python function assigned as an attribute and recognize that the function closes over layers.
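To make the tracking rule concrete, here is a minimal, illustrative sketch (the class name is made up for this example). Keras registers layers assigned directly to attributes, and even layers stored inside plain Python containers such as lists, but a layer that only lives inside a closed-over function is invisible:

import tensorflow as tf
from tensorflow import keras

class TrackingDemo(keras.Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Tracked: a Layer assigned directly to an attribute.
        self.dense = keras.layers.Dense(4)
        # Tracked: Layers inside a plain Python list are also discovered.
        self.blocks = [keras.layers.Dense(4), keras.layers.Dense(1)]
        # NOT tracked: a layer that only exists inside a closure.
        hidden = keras.layers.Dense(4)
        self.closure = lambda x: hidden(x)

    def call(self, x, training=None):
        x = self.dense(x)
        for block in self.blocks:
            x = block(x)
        return self.closure(x)

demo = TrackingDemo()
demo(tf.zeros((1, 8)))  # build all of the layers
# Prints 6: kernel + bias for the three tracked Dense layers.
# The closed-over layer's two variables are missing.
print(len(demo.trainable_variables))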
To fix this, construct a custom layer, a functional model, or a Sequential model representing the same structure:
def DenseBlock():
    return keras.Sequential([keras.layers.Dense(10), keras.layers.Dense(1)])
Or:
class DenseBlock(keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.dense_1 = keras.layers.Dense(10)
        self.dense_2 = keras.layers.Dense(1)

    def call(self, x, training=None):
        x = self.dense_1(x)
        return self.dense_2(x)
That's all; now our custom model works again:
class CustomModel(keras.Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.dense_block = DenseBlock()

    def call(self, x, training=None):
        return self.dense_block(x)

    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            # Compute the loss value
            # (the loss function is configured in `compile()`)
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        return {m.name: m.result() for m in self.metrics}
model = CustomModel()
model.compile(optimizer="adam", loss="mse")
# Just use `fit` as usual -- you can use callbacks, etc.
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=5)
Epoch 1/5
32/32 [==============================] - 0s 1ms/step - loss: 0.3773
Epoch 2/5
32/32 [==============================] - 0s 1ms/step - loss: 0.2919
Epoch 3/5
32/32 [==============================] - 0s 1ms/step - loss: 0.2321
Epoch 4/5
32/32 [==============================] - 0s 1ms/step - loss: 0.1915
Epoch 5/5
32/32 [==============================] - 0s 1ms/step - loss: 0.1594
<keras.callbacks.History at 0x12ef8c3d0>
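And the diagnostic that originally exposed the problem now confirms the fix:

# 4 variables: kernel + bias for each Dense layer inside DenseBlock.
print(len(model.trainable_variables))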
Conclusion
The takeaway is that subclass models can't track variables created inside closure-style layer functions.
Unless you really know Keras inside and out, just use tf.keras.layers.Layer when building a subclass model, and keep closure-style functions for the functional model-building API.
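For completeness, here is an illustrative sketch (not code from keras_cv) of why the closure pattern is safe in the functional API: the closure is applied to symbolic keras.Input tensors at construction time, so the resulting functional keras.Model holds references to every layer that touched the graph, and their weights are tracked.

import numpy as np
from tensorflow import keras

def DenseBlock():
    dense_1 = keras.layers.Dense(10)
    dense_2 = keras.layers.Dense(1)

    def apply(x):
        return dense_2(dense_1(x))

    return apply

# Functional model construction: the layers called inside the closure
# become part of the graph, so the model tracks them.
inputs = keras.Input(shape=(32,))
outputs = DenseBlock()(inputs)
model = keras.Model(inputs, outputs)

model.compile(optimizer="adam", loss="mse")
# 4 variables this time: the weights are tracked, and the model trains.
print(len(model.trainable_variables))
model.fit(np.random.random((1000, 32)), np.random.random((1000, 1)), epochs=1)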