TensorFlow 2.0 Tutorial 04: Early Stopping

During training, the weights of a neural network are updated so that the model performs better on the training data. For a while, improvements on the training set correlate positively with improvements on the test set. Eventually, however, the model begins to overfit: further "improvements" on the training data lower generalization performance. Early stopping is a technique that terminates training before overfitting sets in.

This tutorial explains how early stopping is implemented in TensorFlow 2. All code for this tutorial is available in our repository.

Early stopping is implemented in TensorFlow via the tf.keras.callbacks.EarlyStopping callback:

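from tensorflow.keras.callbacks import EarlyStopping

# Stop once val_accuracy fails to improve by more than 0.0001 for one epoch.
earlystop_callback = EarlyStopping(
  monitor='val_accuracy', min_delta=0.0001,
  patience=1)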

monitor names the quantity used to decide whether training should be terminated. In this case, we use the validation accuracy. min_delta is the minimum change in the monitored quantity that counts as an improvement; here, the validation accuracy must exceed its best value so far by more than 0.0001. patience is the number of epochs without improvement to wait before training is stopped. With patience = 1, training terminates right after the first epoch with no improvement.
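How min_delta and patience interact is easier to see in plain Python. The sketch below is not the Keras source; it is a simplified stand-in that assumes a "higher is better" metric such as accuracy, and the function name first_stopping_epoch and the toy scores are hypothetical.

```python
def first_stopping_epoch(values, min_delta=0.0, patience=0):
    """Return the 1-based epoch at which training would stop, or None.

    Simplified sketch of the early stopping rule for a metric where
    higher is better (e.g. validation accuracy).
    """
    best = float("-inf")  # best monitored value seen so far
    wait = 0              # consecutive epochs without improvement
    for epoch, current in enumerate(values, start=1):
        if current - min_delta > best:
            # Improvement beyond min_delta: record it and reset the counter.
            best = current
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None  # the stopping rule never fired


# Hypothetical validation accuracies, one per epoch.
toy_scores = [0.50, 0.60, 0.60005, 0.59, 0.65]

print(first_stopping_epoch(toy_scores, min_delta=0.0001, patience=1))
print(first_stopping_epoch(toy_scores, min_delta=0.0001, patience=2))
print(first_stopping_epoch(toy_scores, min_delta=0.0001, patience=3))
```

With patience=1 the sketch stops at epoch 3, because a gain of 0.00005 is below min_delta. With patience=2 it stops at epoch 4. With patience=3 it never stops, since epoch 5 improves on the best value and resets the counter.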

Now, we can attach the early stop callback and run training with early stopping:

model.fit(train_dataset,
          epochs=10, callbacks=[earlystop_callback],
          validation_data=test_dataset,
          validation_freq=1)

Epoch 1/10
390/390 [==============================] - 73s 187ms/step - loss: 2.7133 - accuracy: 0.3300 - val_loss: 6.3186 - val_accuracy: 0.1752
Epoch 2/10
390/390 [==============================] - 39s 100ms/step - loss: 2.2262 - accuracy: 0.4914 - val_loss: 2.5499 - val_accuracy: 0.4358
Epoch 3/10
390/390 [==============================] - 39s 100ms/step - loss: 1.9842 - accuracy: 0.5801 - val_loss: 2.5666 - val_accuracy: 0.4708
Epoch 4/10
390/390 [==============================] - 39s 99ms/step - loss: 1.8144 - accuracy: 0.6333 - val_loss: 2.2643 - val_accuracy: 0.5407
Epoch 5/10
390/390 [==============================] - 39s 99ms/step - loss: 1.6799 - accuracy: 0.6770 - val_loss: 2.1015 - val_accuracy: 0.5841
Epoch 6/10
390/390 [==============================] - 39s 99ms/step - loss: 1.5700 - accuracy: 0.7104 - val_loss: 2.0468 - val_accuracy: 0.6078
Epoch 7/10
390/390 [==============================] - 38s 98ms/step - loss: 1.4697 - accuracy: 0.7388 - val_loss: 2.0628 - val_accuracy: 0.5925
Epoch 00007: early stopping

Notice that the 7th epoch resulted in better training accuracy but lower validation accuracy. Thus, training terminated at the 7th epoch even though the maximum number of epochs was set to 10.
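As a quick check of the log above, the following sketch computes how much each epoch's validation accuracy gained over the best value seen so far. The val_accuracy numbers are copied from the log; the variable names are illustrative, and the loop only approximates the comparison the callback makes.

```python
# Validation accuracies for epochs 1-7, copied from the training log.
val_accuracy = [0.1752, 0.4358, 0.4708, 0.5407, 0.5841, 0.6078, 0.5925]
min_delta = 0.0001

best = val_accuracy[0]
gains = []  # gain over the best-so-far value, for epochs 2..7
for acc in val_accuracy[1:]:
    gains.append(round(acc - best, 4))
    best = max(best, acc)

for epoch, gain in enumerate(gains, start=2):
    status = "improved" if gain > min_delta else "no improvement"
    print(f"epoch {epoch}: {gain:+.4f} ({status})")
```

Epochs 2 through 6 each gain well over 0.0001, while epoch 7 falls 0.0153 below the best value of 0.6078. That is the single "no improvement" epoch that patience=1 allows.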

Summary

This tutorial explains how early stopping is implemented in TensorFlow 2. The key takeaway is to use the tf.keras.callbacks.EarlyStopping callback. Early stopping is triggered by monitoring whether a certain value (for example, validation accuracy) has improved over the most recent epochs (how many is controlled by the patience argument).

To reproduce these results, please refer to this code repo.