
Deep Learning: When To Stop Training a Neural Network?

When To Stop Training In Deep Learning?

Quick Recap

An artificial neural network is a collection of artificial neurons that do some math to approximate a mathematical function. This approximation process is called training or fitting.

Basic training mechanism

The math involved in an ANN is mostly MAC (Multiply-Accumulate) operations, where the input is multiplied by weights and biases are added to the product. An activation function is applied to the result, which is forwarded to the next layer, and the same process continues until it reaches the final layer. This process is called feed-forward.
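To make the MAC step concrete, here is a minimal NumPy sketch of one layer's feed-forward computation. The layer sizes, random values, and ReLU activation are arbitrary choices for illustration, not part of any particular network.

```python
import numpy as np

# One dense layer's feed-forward step: multiply inputs by weights,
# add biases, then apply an activation (ReLU here, chosen arbitrarily).
rng = np.random.default_rng(0)
x = rng.normal(size=(1, 4))   # a single input sample with 4 features
W = rng.normal(size=(4, 3))   # weights connecting 4 inputs to 3 neurons
b = np.zeros((1, 3))          # one bias per neuron

z = x @ W + b                 # the MAC (multiply-accumulate) step
a = np.maximum(0, z)          # ReLU activation, forwarded to the next layer
```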

After the final layer's calculation, the output computed by the network is compared with the actual output. The difference between the actual output and the estimated output is calculated using a function called the loss function. Common loss functions are Mean Squared Error, Mean Absolute Error, Root Mean Squared Error, Cross-Entropy, etc.
The error calculated by the loss function is propagated backward through the neural network in the form of gradients. This phase is called backpropagation. In this phase, all of the network parameters (weights and biases) are updated to minimize the error using another function called an optimizer. Common optimizers include SGD, Adam, Adadelta, Nadam, RMSprop, etc.
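Here is a rough sketch of one feed-forward plus backpropagation step written out by hand with TensorFlow's GradientTape. The tiny model, MSE loss, SGD optimizer, and random data are placeholders chosen only to illustrate the flow, not a real training setup.

```python
import tensorflow as tf

# A single feed-forward + backpropagation step, written out explicitly.
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
loss_fn = tf.keras.losses.MeanSquaredError()              # the loss function
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)   # the optimizer

x = tf.random.normal((8, 4))   # a batch of 8 samples (made-up data)
y = tf.random.normal((8, 1))   # their "actual" outputs (made-up data)

with tf.GradientTape() as tape:
    y_pred = model(x, training=True)   # feed-forward
    loss = loss_fn(y, y_pred)          # compare estimated vs. actual output

grads = tape.gradient(loss, model.trainable_variables)             # backpropagation
optimizer.apply_gradients(zip(grads, model.trainable_variables))   # update weights and biases
```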

One feed-forward pass and one backpropagation pass make a single iteration. After each iteration, the model moves toward a better estimate. The number of iterations per epoch is the total number of data samples divided by the batch size. For example, if a dataset contains 10,000 samples and it is fed to the network 100 samples per batch, the entire dataset is processed in 10,000/100 = 100 iterations. This makes a single epoch.
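In code, the arithmetic from the example above is simply (a trivial sketch using the numbers from the text):

```python
# Iterations per epoch = total samples / batch size.
num_samples = 10_000
batch_size = 100
iterations_per_epoch = num_samples // batch_size
print(iterations_per_epoch)   # 100 iterations = 1 epoch
```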

What is overfitting and underfitting? 

After each epoch, the model improves and estimates the output better. The problem is that too much training leads to overfitting. An overfit model performs well only on the data it was trained on and performs poorly on unseen data, which means training needs to be stopped before the model overfits the data. However, that is only one end of the problem. The other end is that if training is stopped too early, the model underfits. An underfit model gives mediocre or poor performance on both seen and unseen data.

So when do we stop training?

To avoid overfitting and underfitting, we need to stop at a point where the model neither overfits nor underfits. But there is no theory that can prescribe such a perfect point. So how can we know when to stop training?
Fortunately, there are algorithms designed to decide the stopping point, and all of the popular deep learning libraries provide implementations of them. As a fan and user of TensorFlow, I will talk about only one of the functions TensorFlow provides, called EarlyStopping. It is a callback that needs to be hooked into the train/fit function, and it is available in the tf.keras.callbacks module.
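As a rough sketch of how the callback is hooked into training: the model, the random data, and the hyperparameter values below are placeholders for illustration, not a recommendation.

```python
import tensorflow as tf

# EarlyStopping watches the validation metrics after every epoch and
# stops training once the model stops improving.
early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=3)

model = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation='relu'),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer='adam', loss='mse')

x = tf.random.normal((1000, 4))   # made-up data for the sketch
y = tf.random.normal((1000, 1))

# Hook the callback into fit(); validation_split provides the val_loss it monitors.
model.fit(x, y, validation_split=0.2, epochs=100, callbacks=[early_stop])
```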

What it actually does is monitor the neural network model during training and issue a stop command when the model no longer improves. Simple, but read it again.

To better understand this, let's assume you are studying at a school with a strict monitoring and evaluation policy (I know it is a horrible scene, but bear with me). Your performance is evaluated at the end of each term. They keep you in the school as long as you are learning and improving. But as soon as you stop learning, they give you a warning and ask you to return to your best learning mode. You try again the next term and fail to improve. They give you another warning and let you keep doing the school things. You try again, harder this time, but fail to learn effectively, and this time they kick your ass out of the school.

Now think about yourself as the artificial neural network in the above story and your school as the legendary EarlyStopping callback.

EarlyStopping gives you the ability to design your own custom convergence criteria through a comprehensive list of parameters (see the list below; a configuration sketch follows the list).

monitor (string): this parameter specifies the metric used as the basis for the holy decision (to stop the training). Remember how they evaluated you in the school story? Yeah, it is the grades, curriculum, and behavior thing. This parameter typically takes one of two options: 'val_loss' and 'val_acc'. The first refers to the loss value of your network on the validation data after an epoch, while the latter refers to its validation accuracy.

min_delta (float): the minimum change in the monitored quantity that counts as an improvement. For example, if you assign 'val_acc' to monitor and 0.01 to min_delta, and your model has just finished 3 epochs, then the accuracy after epoch 3 must be greater than the accuracy after epoch 2 by at least 0.01 to count as an improvement; otherwise, it is treated as no improvement.

patience (integer): patience is similar to the number of warnings given to the student in the school story. It specifies the number of epochs the EarlyStopping callback waits for the model to improve. For example, if it is given a value of 3, the callback will tolerate up to 3 epochs of no improvement before stopping the training.

verbose (integer): I bet you know this. No explanation.

mode (string): one of {"auto", "min", "max"}. In 'min' mode, training stops when the monitored quantity has stopped decreasing; in 'max' mode, it stops when the monitored quantity has stopped increasing; in 'auto' mode, the direction is inferred automatically from the name of the monitored quantity, i.e. for 'val_loss' the inferred direction is 'min' and for 'val_acc' it is 'max'.

baseline (float): baseline value for the monitored quantity. Training stops if the model doesn't show improvement over the baseline.

restore_best_weights (bool): Whether to restore model weights from the epoch with the best value of the monitored quantity. If False, the model weights obtained at the last step of training are used.
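Putting the parameters together, a configuration might look like the sketch below. The specific values are arbitrary illustrations and would need tuning for a real problem.

```python
import tensorflow as tf

early_stop = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',          # watch the validation loss
    min_delta=0.001,             # changes smaller than this don't count as improvement
    patience=3,                  # tolerate 3 epochs without improvement (the "warnings")
    verbose=1,                   # print a message when training is stopped early
    mode='min',                  # lower val_loss is better
    baseline=None,               # no required starting level for the monitored metric
    restore_best_weights=True,   # roll back to the best-performing weights
)

# Hook it into training exactly as before:
# model.fit(x, y, validation_split=0.2, epochs=100, callbacks=[early_stop])
```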

Sample Python code is available on GitHub.
