A Common Pitfall When Writing Subclass Models in Keras

Overview

The other day, while benchmarking the KerasCV RetinaNet, I was shocked to see that the prediction and regression heads were not training. The model achieved mediocre scores, but the weights were never being updated. After almost a week of debugging, I found the culprit.

TLDR: subclass models cannot track the weights used in closure-style components, which are a commonly used abstraction in the keras_cv.models subpackage.

Demonstration

Let's demonstrate this edge case:

import tensorflow as tf
from tensorflow import keras
import numpy as np


class CustomModel(keras.Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.dense_1 = keras.layers.Dense(10)
        self.dense_2 = keras.layers.Dense(1)

    def call(self, x, training=None):
        x = self.dense_1(x)
        return self.dense_2(x)

    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            # Compute the loss value
            # (the loss function is configured in `compile()`)
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)

        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        return {m.name: m.result() for m in self.metrics}

This is a properly written subclass model. You can fit it using:

model = CustomModel()
model.compile(optimizer="adam", loss="mse")

# Just use `fit` as usual -- you can use callbacks, etc.
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=5)
Epoch 1/5
32/32 [==============================] - 0s 1ms/step - loss: 0.1832
Epoch 2/5
32/32 [==============================] - 0s 1ms/step - loss: 0.1476
Epoch 3/5
32/32 [==============================] - 0s 1ms/step - loss: 0.1306
Epoch 4/5
32/32 [==============================] - 0s 1ms/step - loss: 0.1150
Epoch 5/5
32/32 [==============================] - 0s 1ms/step - loss: 0.1073

<keras.callbacks.History at 0x12e7dc670>

Cool, it works. That's all there is to it. Let's rewrite this using the closure pattern to abstract the two dense layers into a single block.


def DenseBlock():
    dense_1 = keras.layers.Dense(10)
    dense_2 = keras.layers.Dense(1)

    def apply(x, training=None):
        x = dense_1(x)
        return dense_2(x)

    return apply

Great, and let's use it in our model:


class CustomModel(keras.Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.dense_block = DenseBlock()

    def call(self, x, training=None):
        return self.dense_block(x)

    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            # Compute the loss value
            # (the loss function is configured in `compile()`)
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)

        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)  # Broken
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        return {m.name: m.result() for m in self.metrics}


model = CustomModel()
model.compile(optimizer="adam", loss="mse")

# Just use `fit` as usual -- you can use callbacks, etc.
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=5)

Epoch 1/5
32/32 [==============================] - 0s 974us/step - loss: 0.5985
Epoch 2/5
32/32 [==============================] - 0s 1ms/step - loss: 0.5985
Epoch 3/5
32/32 [==============================] - 0s 983us/step - loss: 0.5985
Epoch 4/5
32/32 [==============================] - 0s 1ms/step - loss: 0.5985
Epoch 5/5
32/32 [==============================] - 0s 1ms/step - loss: 0.5985

<keras.callbacks.History at 0x12eee7cd0>

Notice something? The model doesn't learn.

Why? Let's dig deeper:

print(model.trainable_variables)
[]

There lies the culprit: the model doesn't know about the variables created inside DenseBlock(). Keras tracks weights by inspecting the attributes you assign to a Layer or Model (layers, variables, and containers of them); a plain Python function that merely closes over layers is opaque to this mechanism, so the weights of dense_1 and dense_2 are never registered.
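
As an aside, a cheap way to catch this class of bug early is to call the model once on dummy data and assert that some trainable weights were actually tracked. This is just a minimal sketch (not part of the original debugging session), assuming the 32-feature inputs used above:


# Sanity check: build the model by calling it once, then verify that Keras
# tracked some trainable weights. With the closure-based DenseBlock this
# assertion fails, surfacing the bug before any training is run.
model = CustomModel()
model(np.random.random((1, 32)).astype("float32"))
assert len(model.trainable_variables) > 0, "No trainable variables were tracked!"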

To fix this, just construct a custom layer, a functional model, or a Sequential model representing the same structure:


def DenseBlock():
    return keras.Sequential([keras.layers.Dense(10), keras.layers.Dense(1)])

Or:


class DenseBlock(keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.dense_1 = keras.layers.Dense(10)
        self.dense_2 = keras.layers.Dense(1)

    def call(self, x, training=None):
        x = self.dense_1(x)
        return self.dense_2(x)

That's all. Now our custom model will work again:


class CustomModel(keras.Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.dense_block = DenseBlock()

    def call(self, x, training=None):
        return self.dense_block(x)

    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            # Compute the loss value
            # (the loss function is configured in `compile()`)
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)

        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        return {m.name: m.result() for m in self.metrics}


model = CustomModel()
model.compile(optimizer="adam", loss="mse")

# Just use `fit` as usual -- you can use callbacks, etc.
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=5)
Epoch 1/5
32/32 [==============================] - 0s 1ms/step - loss: 0.3773
Epoch 2/5
32/32 [==============================] - 0s 1ms/step - loss: 0.2919
Epoch 3/5
32/32 [==============================] - 0s 1ms/step - loss: 0.2321
Epoch 4/5
32/32 [==============================] - 0s 1ms/step - loss: 0.1915
Epoch 5/5
32/32 [==============================] - 0s 1ms/step - loss: 0.1594

<keras.callbacks.History at 0x12ef8c3d0>

Conclusion

The takeaway is that subclass models can't track variables used in closure-style layer blocks.

Unless you really know Keras inside and out, just use tf.keras.layers.Layer when building a subclass model. Keep closure-style functions to the functional model-building API.
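
To see why the closure pattern is fine there, here is a minimal sketch (assuming the same 32-feature inputs as above, and the closure-style DenseBlock from earlier): in the functional API the layers are called on a symbolic tensor, so the resulting keras.Model discovers them from the graph and tracks their weights.


def DenseBlock():
    dense_1 = keras.layers.Dense(10)
    dense_2 = keras.layers.Dense(1)

    def apply(x):
        return dense_2(dense_1(x))

    return apply


# Functional model building: the closure's layers are called on a symbolic
# input, so the resulting model tracks all four weights
# (two kernels and two biases).
inputs = keras.Input(shape=(32,))
outputs = DenseBlock()(inputs)
functional_model = keras.Model(inputs, outputs)
print(len(functional_model.trainable_variables))  # 4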