
Deep Learning: When To Stop Training A Neural Network?

When To Stop Training In Deep Learning?

Quick Recap

An artificial neural network is a collection of artificial neurons that each do some math, and together they try to approximate a mathematical function. This approximation process is called training or fitting.

Basic training mechanism

The math involved in an ANN is mostly MAC (Multiply-Accumulate) operations: the input is multiplied by weights, and a bias is added to the product. An activation function is applied to the result, which is forwarded to the next layer, and the same process repeats until it reaches the final layer. This process is called feed-forward.
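As a minimal sketch of that MAC-plus-activation step for a single neuron (the weights, bias, and inputs below are illustrative values, not from any real network):

```python
# Sketch of the multiply-accumulate + activation step for one neuron.
# All numbers here are made-up for illustration.

def relu(z):
    # A common activation function: max(0, z).
    return max(0.0, z)

def neuron_output(inputs, weights, bias):
    # Multiply-accumulate: sum of input*weight products, plus the bias.
    z = sum(x * w for x, w in zip(inputs, weights)) + bias
    return relu(z)

out = neuron_output(inputs=[1.0, 2.0], weights=[0.5, -0.25], bias=0.1)
# z = 1.0*0.5 + 2.0*(-0.25) + 0.1 = 0.1, so out = 0.1
```

A real layer computes many of these in parallel (one per neuron) and passes the vector of outputs to the next layer.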

After the final layer's calculation, the output computed by the network is compared with the actual output. The difference between the actual output and the estimated output is calculated using a function called the loss function. Common loss functions these days are Mean Squared Error, Mean Absolute Error, Root Mean Squared Error, Cross-Entropy, etc.
The error calculated by the loss function is propagated backward through the neural network in the form of gradients. This phase is called backpropagation. In this phase, all the network parameters (weights and biases) are updated to minimize the error using another function called an optimizer. There are different optimizers such as SGD, Adam, Adadelta, Nadam, RMSprop, etc.
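For instance, Mean Squared Error can be written in a few lines. This is a toy illustration of the formula, not tied to any particular library; the example values are made up:

```python
# Toy Mean Squared Error: the average of squared differences between
# the actual outputs and the network's estimated outputs.

def mse(actual, predicted):
    return sum((a - p) ** 2 for a, p in zip(actual, predicted)) / len(actual)

loss = mse(actual=[1.0, 0.0, 1.0], predicted=[0.9, 0.2, 0.8])
# (0.01 + 0.04 + 0.04) / 3 = 0.03
```

The smaller this number, the closer the network's estimates are to the targets; the optimizer's job is to nudge the weights and biases in the direction that shrinks it.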

One feed-forward pass and one backpropagation pass make a single iteration. After each iteration, the model moves toward a better estimate. The number of iterations per epoch is the ratio of total data samples to batch size. For example, if a dataset contains 10,000 samples and it is fed to the network 100 samples per batch, then the entire dataset is processed in 10,000/100 = 100 iterations. This makes a single epoch.
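The iteration count from the example above is just integer division:

```python
# Iterations per epoch = total samples / batch size.
total_samples = 10_000
batch_size = 100

iterations_per_epoch = total_samples // batch_size
# 10_000 // 100 = 100 iterations make one epoch.
```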

What is overfitting and underfitting? 

After each epoch, the model improves and estimates better output. The problem is that too much training leads to overfitting. An overfit model performs well only on the data it has been trained on and performs poorly on unseen data. This means training needs to be stopped before the model overfits the data. However, that is just one end of the problem. The other end is that if training is stopped too early, the model underfits. An underfit model gives mediocre or poor performance on both seen and unseen data.

So when do we stop training?

To avoid overfitting and underfitting, we need an optimal value of the network loss at which the model neither overfits nor underfits. But there is no theory that can recommend a perfect stopping point. So how can we know when to stop training?
Fortunately, there are algorithms designed to determine the stopping point, and all the popular deep learning libraries provide functions for them. As a fan and user of TensorFlow, I will talk about only one of the functions provided by TensorFlow, called EarlyStopping. It is a callback that needs to be hooked into the train/fit function. It is available in the tf.keras.callbacks module.
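A minimal sketch of hooking EarlyStopping into model.fit (this assumes TensorFlow 2.x; the tiny model and synthetic data here are made up for illustration, not from the original post):

```python
import numpy as np
import tensorflow as tf

# Synthetic toy data: 200 samples with 10 features, binary labels.
x = np.random.rand(200, 10).astype("float32")
y = (x.sum(axis=1) > 5.0).astype("float32")

model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, activation="relu"),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy",
              metrics=["accuracy"])

# The callback watches validation loss and stops training once it
# has failed to improve for `patience` consecutive epochs.
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss",
    min_delta=0.001,
    patience=3,
    restore_best_weights=True,
)

history = model.fit(x, y, validation_split=0.2, epochs=50,
                    callbacks=[early_stop], verbose=0)
```

With this hookup, `epochs=50` is just an upper bound; the run may end well before that if validation loss stalls.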

What it actually does is monitor the neural network model during training and issue a stop command when the model no longer improves. Simple, but read it again.

To better understand this, let's assume that you are studying at a school which has a strict monitoring and evaluation policy (I know it is a horrible scene, but bear with me). Your performance is evaluated at the end of each term. They keep you in the school as long as you are learning and improving. But as soon as you stop learning, they give you a warning and ask you to return to your best learning mode. You try again the next term and fail to improve. They give you another warning and let you keep doing the school things. You try again, harder this time, but fail to learn effectively, and this time they kick you out of the school.

Now think of yourself as the artificial neural network in the above story and your school as the legendary EarlyStopping callback.

EarlyStopping gives you the ability to design your own custom convergence criteria by providing a comprehensive list of parameters (see the list below).

monitor (string): this parameter specifies the metric which will be used as the basis for making the holy decision (to stop the training). Remember how they evaluated you in the school story? Yeah, it is the grades, curriculum, and behavior thing. This parameter can be set to any metric logged during training; the two most common choices are 'val_loss' and 'val_acc' ('val_accuracy' in newer Keras versions). The first refers to the loss of your network on validation data after an epoch, while the latter refers to the accuracy of the model on validation data.

min_delta (float): the minimum change in the monitored quantity to be considered an improvement. For example, if you assign 'val_acc' to monitor and 0.01 to min_delta, and your model has just finished 3 epochs, then the accuracy after epoch 3 must be greater than the accuracy after epoch 2 by at least 0.01 to count as an improvement; otherwise, there is no improvement.

patience (integer): patience is similar to the number of warnings given to the student in the school story. It specifies the number of non-improving epochs the EarlyStopping callback will wait for the model to improve before stopping. For example, with a value of 3, the callback will tolerate three consecutive epochs of no improvement before ending the training.

verbose (integer): I bet you know this. No explanation.

mode (string): one of {"auto", "min", "max"}. In 'min' mode, training stops when the monitored quantity has stopped decreasing; in 'max' mode, it stops when the monitored quantity has stopped increasing; in 'auto' mode, the direction is inferred automatically from the name of the monitored quantity, i.e. for 'val_loss' the inferred mode is 'min' and for 'val_acc' it is 'max'.

baseline (float): baseline value for the monitored quantity. Training will stop if the model doesn't show improvement over the baseline.

restore_best_weights (bool): Whether to restore model weights from the epoch with the best value of the monitored quantity. If False, the model weights obtained at the last step of training are used.
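The interplay of monitor, min_delta, patience, and mode can be simulated in a few lines of plain Python. This is a sketch of the stopping logic only, not the actual Keras implementation (which lives in tf.keras.callbacks):

```python
def early_stopping_epochs(values, min_delta=0.0, patience=0, mode="min"):
    """Return how many epochs would run before stopping, given the
    per-epoch values of the monitored quantity.
    A sketch of EarlyStopping's logic, not the Keras implementation."""
    best = values[0]
    wait = 0
    for epoch, value in enumerate(values[1:], start=2):
        if mode == "min":
            improved = (best - value) > min_delta   # e.g. val_loss dropped
        else:  # "max"
            improved = (value - best) > min_delta   # e.g. val_acc rose
        if improved:
            best = value
            wait = 0              # improvement resets the warning count
        else:
            wait += 1             # one more "warning"
            if wait >= patience:  # patience exhausted: stop here
                return epoch
    return len(values)

# val_loss improves for 3 epochs, then stalls; with patience=2 the run
# stops after epoch 5 (two non-improving epochs in a row).
stopped = early_stopping_epochs([0.9, 0.7, 0.5, 0.5, 0.5, 0.5], patience=2)
# stopped == 5
```

If the quantity keeps improving every epoch, the loop runs to the end, which is exactly the case where EarlyStopping never fires.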

Sample Python code at GitHub

