Training a Cassava Classifier with KerasCV's CutMix and MixUp Layers



Recently my full-time work has shifted towards KerasCV, the official home of computer vision extensions for Keras. Over the next few months, KerasCV will be populated with reusable Keras components focused on computer vision, such as layers, losses, preprocessing layers, and model architectures.

Today I published the first feature to KerasCV: the CutMix and MixUp preprocessing layers. CutMix and MixUp are two state-of-the-art data augmentation techniques, described in "CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features" and "mixup: Beyond Empirical Risk Minimization" respectively. These preprocessing layers are a key component of training state-of-the-art image classification models.
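To make the idea concrete, here is a minimal NumPy sketch of MixUp as the paper describes it: each image is blended with a partner drawn from the same batch using a weight λ sampled from a Beta(α, α) distribution, and the one-hot labels are blended with the same weight. The helper name mixup_batch and the alpha default are illustrative assumptions. This is not KerasCV's implementation, and it samples one λ per example rather than one per batch.

```python
import numpy as np

def mixup_batch(images, labels, alpha=0.2, rng=None):
    """Blend each example with a random partner from the same batch.

    images: float array, shape (batch, height, width, channels)
    labels: one-hot (or already soft) labels, shape (batch, num_classes)
    alpha:  Beta(alpha, alpha) concentration; hypothetical default
    """
    rng = np.random.default_rng() if rng is None else rng
    batch_size = images.shape[0]
    # One mixing weight per example, drawn from Beta(alpha, alpha).
    lam = rng.beta(alpha, alpha, size=batch_size)
    # Partners are a random permutation of the same batch.
    perm = rng.permutation(batch_size)
    lam_img = lam.reshape(-1, 1, 1, 1)  # broadcast over H, W, C
    lam_lbl = lam.reshape(-1, 1)        # broadcast over classes
    mixed_images = lam_img * images + (1.0 - lam_img) * images[perm]
    mixed_labels = lam_lbl * labels + (1.0 - lam_lbl) * labels[perm]
    return mixed_images, mixed_labels

# Usage: 4 random 8x8 RGB "images" with one-hot labels over 5 classes.
rng = np.random.default_rng(0)
images = rng.random((4, 8, 8, 3))
labels = np.eye(5)[[0, 1, 2, 3]]
mixed_images, mixed_labels = mixup_batch(images, labels, rng=rng)
print(mixed_labels.sum(axis=1))  # each mixed label still sums to 1
```

Because the same λ weights both the pixels and the labels, the target stays a valid class distribution that reflects how much of each source image is present.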

This quick guide shows how to use these layers to train a ResNet50V2 to classify Cassava plant leaves as either healthy or afflicted with one of four diseases: cmd, cgm, cbsd, or cbb. Cassava is the second-largest provider of carbohydrates in Africa, and at least 80% of small farms grow the plant. Disease is a common cause of poor yields, so automatically identifying diseased plants is invaluable.


First, we need to install the keras-cv package. For now, the package can be installed from GitHub with the following commands:

!git clone
!cd keras-cv && pip install . -q

Setting up a Pipeline

Let's set up a data augmentation pipeline. Our pipeline will crop the images, apply random flips, rotations, and zooms, and then augment them with CutMix and MixUp.

import tensorflow_datasets as tfds
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import applications
from keras_cv.layers.preprocessing import cut_mix
from keras_cv.layers.preprocessing import mix_up

# Constants used throughout this guide. Their values are not shown in the
# original post; the ones below are assumptions.
IMAGE_SIZE = 224
BATCH_SIZE = 32
AUTO = tf.data.AUTOTUNE

data, ds_info = tfds.load("cassava", with_info=True, as_supervised=True)
num_classes = ds_info.features["label"].num_classes
data
{'test': <PrefetchDataset shapes: ((None, None, 3), ()), types: (tf.uint8, tf.int64)>,
 'train': <PrefetchDataset shapes: ((None, None, 3), ()), types: (tf.uint8, tf.int64)>,
 'validation': <PrefetchDataset shapes: ((None, None, 3), ()), types: (tf.uint8, tf.int64)>}

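Next, let's define a function that we apply to both the training and evaluation datasets. It scales pixel values to [0, 1], crops every image to IMAGE_SIZE x IMAGE_SIZE, and one-hot encodes the labels. One-hot labels are needed because CutMix and MixUp blend labels into class distributions rather than single integer classes.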

crop = layers.RandomCrop(IMAGE_SIZE, IMAGE_SIZE)
def preprocess(image, label):
  image = tf.cast(image, tf.float32)/255.0
  return crop(image), tf.one_hot(label, num_classes)

Next, we can create an instance of the CutMix and MixUp layers provided by KerasCV. Both of these data augmentation methods modify labels, so be sure to pass the labels to the call method.
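The label change is easiest to see with CutMix. Below is a minimal NumPy sketch, assuming one-hot labels and a hypothetical cutmix_batch helper (this is not KerasCV's code): a rectangle from a partner image in the same batch is pasted into each image, and the labels are mixed in proportion to the pasted area, following the area-based label weighting from the CutMix paper. The alpha parameter of the Beta distribution is an illustrative assumption.

```python
import numpy as np

def cutmix_batch(images, labels, alpha=1.0, rng=None):
    """Paste a random box from a partner image into each image.

    Labels are mixed by the fraction of the image area that was pasted.
    alpha is a hypothetical Beta(alpha, alpha) parameter.
    """
    rng = np.random.default_rng() if rng is None else rng
    batch_size, height, width, _ = images.shape
    perm = rng.permutation(batch_size)          # partner for each example
    mixed_images = images.copy()
    mixed_labels = labels.astype(np.float64).copy()
    for i in range(batch_size):
        lam = rng.beta(alpha, alpha)
        # Box side lengths chosen so the box covers about (1 - lam) of the area.
        cut_h = int(height * np.sqrt(1.0 - lam))
        cut_w = int(width * np.sqrt(1.0 - lam))
        # Random box centre; the box is clipped at the image border.
        cy = rng.integers(0, height)
        cx = rng.integers(0, width)
        y1, y2 = np.clip([cy - cut_h // 2, cy + cut_h // 2], 0, height)
        x1, x2 = np.clip([cx - cut_w // 2, cx + cut_w // 2], 0, width)
        mixed_images[i, y1:y2, x1:x2, :] = images[perm[i], y1:y2, x1:x2, :]
        # Use the area actually pasted, since clipping can shrink the box.
        pasted = (y2 - y1) * (x2 - x1) / (height * width)
        mixed_labels[i] = (1.0 - pasted) * labels[i] + pasted * labels[perm[i]]
    return mixed_images, mixed_labels

# Usage: 6 random 16x16 RGB "images" with one-hot labels over 5 classes.
rng = np.random.default_rng(1)
images = rng.random((6, 16, 16, 3))
labels = np.eye(5)[[0, 1, 2, 3, 4, 0]]
mixed_images, mixed_labels = cutmix_batch(images, labels, rng=rng)
print(mixed_labels.round(2))
```

Recomputing the weight from the clipped box keeps the label proportions honest about how much of the partner image is really visible.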

Both layers also support label smoothing. In this example, we apply label smoothing only in the MixUp layer, because MixUp is the last augmentation in our pipeline.

CutMix and MixUp both take one required positional argument, rate, which controls how often the augmentation is applied. An optional label_smoothing argument can also be provided. Label smoothing is recommended alongside these layers, applied in whichever layer comes last in your augmentation pipeline.
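Label smoothing itself is a simple transform. The sketch below shows the standard formulation, y_smooth = y * (1 - ε) + ε / K for K classes, which is the formula applied by the label_smoothing argument of keras.losses.CategoricalCrossentropy. That KerasCV's CutMix and MixUp layers use exactly this formula is an assumption, and smooth_labels is a hypothetical helper.

```python
import numpy as np

def smooth_labels(labels, label_smoothing):
    """Standard label smoothing: move a fraction `label_smoothing` of the
    probability mass from the targets onto a uniform distribution."""
    num_classes = labels.shape[-1]
    return labels * (1.0 - label_smoothing) + label_smoothing / num_classes

one_hot = np.eye(5)[[1]]  # a single label for class 1, out of 5 classes
print(smooth_labels(one_hot, 0.1))
# approximately [[0.02 0.92 0.02 0.02 0.02]]
```

Applying smoothing in a single place, as done here with MixUp, avoids smoothing the same targets twice.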

# rate=0.35 for both layers; label smoothing only in the last layer (MixUp).
cutmix = cut_mix.CutMix(0.35)
mixup = mix_up.MixUp(0.35, label_smoothing=0.1)

Next, we put the preprocessing pipeline together. Both cutmix and mixup take images and labels as arguments and return the updated images and labels.

flip = layers.RandomFlip(mode='horizontal')
rotation = layers.RandomRotation(0.1)
zoom = layers.RandomZoom(0.1)

def augment_data_for_training(images, labels):
  images = flip(images)
  images = rotation(images)
  images = zoom(images)
  images, labels = cutmix(images, labels)
  images, labels = mixup(images, labels)
  return images, labels

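Now we can build the input pipeline for the training dataset. The augmentation function is mapped after batching, because CutMix and MixUp mix each example with other examples from the same batch.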

ds_train = data["train"]
ds_train = ds_train.repeat()
ds_train = ds_train.shuffle(BATCH_SIZE * 10)
ds_train = ds_train.map(preprocess, num_parallel_calls=AUTO)
ds_train = ds_train.batch(BATCH_SIZE)
ds_train = ds_train.map(augment_data_for_training, num_parallel_calls=AUTO)
ds_train = ds_train.prefetch(AUTO)

And the pipeline for our test dataset, which skips augmentation:

ds_test = data["test"]
ds_test = ds_test.repeat()
ds_test = ds_test.shuffle(BATCH_SIZE * 10)
ds_test = ds_test.map(preprocess, num_parallel_calls=AUTO)
ds_test = ds_test.batch(BATCH_SIZE)
ds_test = ds_test.prefetch(AUTO)

Let's visualize some of the data:

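import matplotlib.pyplot as plt

plt.figure(figsize=(8, 8))
# Each element of ds_train is a batch; show the first image from 9 batches.
for i, (images, labels) in enumerate(ds_train.take(9)):
  plt.subplot(3, 3, i + 1)
  plt.imshow(images[0].numpy())
  plt.axis("off")
plt.show()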
Looks great! It's easy to tell which examples were augmented with MixUp, which with CutMix, and which with both.

Model Creation

Now let's create our classification model.

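In this example, we use the ResNet50V2 architecture with randomly initialized weights (weights=None) and a classification head with one output per Cassava class. Training uses categorical cross-entropy as the loss function and the Adam optimizer. For evaluation we track accuracy; top-5 accuracy would be uninformative here, since with only five classes it is always 100%.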
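Categorical cross-entropy is the right loss here because CutMix and MixUp turn labels into class distributions. The loss -Σ y_k log p_k works for soft targets as well as one-hot ones, whereas SparseCategoricalCrossentropy expects a single integer class per example and cannot represent a 70/30 mixed label. The NumPy restatement below is only an illustration: the helper name and the clipping epsilon are assumptions, not Keras internals.

```python
import numpy as np

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Cross-entropy between target distributions and predicted
    probabilities; works for one-hot and soft (mixed) targets alike."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)  # avoid log(0)
    return -np.sum(y_true * np.log(y_pred), axis=-1)

# A MixUp-style target: 70% class 0, 30% class 3.
y_true = np.array([[0.7, 0.0, 0.0, 0.3, 0.0]])
y_pred = np.array([[0.6, 0.1, 0.1, 0.1, 0.1]])
print(categorical_crossentropy(y_true, y_pred))  # about 1.048
```

The loss is -(0.7 * ln 0.6 + 0.3 * ln 0.1), so the model is rewarded for assigning probability to both source classes in proportion to the mix.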

input_shape = (IMAGE_SIZE, IMAGE_SIZE, 3)

def get_model():
  model = applications.ResNet50V2(input_shape=input_shape, classes=num_classes, weights=None)
  model.compile(
      loss=keras.losses.CategoricalCrossentropy(),
      optimizer=keras.optimizers.Adam(),
      metrics=["accuracy"],
  )
  return model

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  model = get_model()

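Finally, we can train our model. Both datasets repeat indefinitely, so we pass steps_per_epoch and validation_steps to tell fit how many batches make up an epoch:

model.fit(ds_train, steps_per_epoch=200, epochs=10, validation_data=ds_test, validation_steps=50)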
Epoch 1/10
200/200 [==============================] - 68s 241ms/step - loss: 1.4083 - accuracy: 0.4838 - val_loss: 1.3327 - val_accuracy: 0.4859
Epoch 2/10
200/200 [==============================] - 44s 219ms/step - loss: 1.3667 - accuracy: 0.5107 - val_loss: 1.4209 - val_accuracy: 0.3833
Epoch 3/10
200/200 [==============================] - 45s 224ms/step - loss: 1.3202 - accuracy: 0.5373 - val_loss: 1.2850 - val_accuracy: 0.4830
Epoch 4/10
200/200 [==============================] - 44s 222ms/step - loss: 1.3134 - accuracy: 0.5428 - val_loss: 1.1832 - val_accuracy: 0.5517
Epoch 5/10
200/200 [==============================] - 44s 221ms/step - loss: 1.2960 - accuracy: 0.5497 - val_loss: 1.2083 - val_accuracy: 0.5745
Epoch 6/10
200/200 [==============================] - 45s 224ms/step - loss: 1.2914 - accuracy: 0.5557 - val_loss: 1.4382 - val_accuracy: 0.4841
Epoch 7/10
200/200 [==============================] - 45s 224ms/step - loss: 1.2665 - accuracy: 0.5639 - val_loss: 1.1275 - val_accuracy: 0.5741
Epoch 8/10
200/200 [==============================] - 44s 223ms/step - loss: 1.2621 - accuracy: 0.5671 - val_loss: 1.1782 - val_accuracy: 0.5811
Epoch 9/10
200/200 [==============================] - 44s 221ms/step - loss: 1.2276 - accuracy: 0.5853 - val_loss: 1.1077 - val_accuracy: 0.5830
Epoch 10/10
200/200 [==============================] - 45s 226ms/step - loss: 1.2160 - accuracy: 0.5892 - val_loss: 1.2181 - val_accuracy: 0.5180

<keras.callbacks.History at 0x7f09fb04fd50>

Model Inference

Let's test our model. In the training run above, validation accuracy peaks at about 58% (epoch 9) and ends at about 52% after the final epoch.
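It helps to be precise about what "accuracy" measures here. Keras' accuracy metric, paired with a categorical cross-entropy loss, compares the argmax of the target with the argmax of the prediction, the same comparison used in the inference code in this section. On training batches the targets are mixed, so training accuracy is scored against each example's dominant class, which is one reason it is not directly comparable to validation accuracy. A small NumPy sketch with made-up values (the helper name is hypothetical):

```python
import numpy as np

def categorical_accuracy(y_true, y_pred):
    """Fraction of examples whose top prediction matches the
    highest-weighted class in the target."""
    return np.mean(np.argmax(y_true, axis=-1) == np.argmax(y_pred, axis=-1))

y_true = np.array([
    [1.0, 0.0, 0.0],  # one-hot target
    [0.0, 0.6, 0.4],  # mixed target, dominant class 1
    [0.0, 0.0, 1.0],
    [0.3, 0.7, 0.0],  # mixed target, dominant class 1
])
y_pred = np.array([
    [0.8, 0.1, 0.1],  # correct
    [0.2, 0.5, 0.3],  # correct
    [0.5, 0.2, 0.3],  # wrong
    [0.6, 0.3, 0.1],  # wrong
])
print(categorical_accuracy(y_true, y_pred))  # 0.5
```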

import numpy as np

# Read class names from the dataset metadata so they match the label indices.
label_titles = ds_info.features["label"].names
for images, label in ds_test.take(1):
  plt.figure(figsize=(8, 8))
  pred_labels = np.argmax(model.predict(images), axis=-1)
  true_labels = np.argmax(label.numpy(), axis=-1)
  plt.suptitle("True Label/Predicted Label")
  for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(images[i].numpy())
    plt.axis("off")
    ax.title.set_text(label_titles[true_labels[i]] + '/' + label_titles[pred_labels[i]])
plt.show()

Not bad! The model has learned to correctly classify the majority of the plants.

Conclusion & Next Steps

With KerasCV, using CutMix and MixUp requires only a few extra lines of code.

The goal of KerasCV is to make state-of-the-art computer vision techniques accessible and easy to use directly with Keras. In the near future, I aim to be able to train state-of-the-art ImageNet classifiers using only built-in components of KerasCV.

As a follow-up exercise, experiment with the model architecture, preprocessing configuration, and optimizer. Much higher scores are possible on the Cassava dataset: try to achieve over 90% top-1 accuracy! One effective approach is to pretrain with contrastive learning on the unlabeled data corpus.

Please let me know if you use my CutMix or MixUp layers, and keep an eye on the repo for the official KerasCV release!

