A case for knowledge distillation


When trying to reduce the cost of one of our machine learning demos, our engineers had to find a way to reduce the complexity of their model without a drastic decrease in accuracy. The method we ended up using was knowledge distillation! While researching the concept, we found that most coverage of knowledge distillation consisted of articles that either did not contain enough information to be useful or were far too detailed to serve as an introduction to the subject. This article is an attempt to bridge that gap, and we hope it contains enough information for you to come away with a strong understanding of how knowledge distillation works without going ‘too far into the weeds’.


When training a neural network, your priority is usually finding a model that maximizes test accuracy. These experimental models often end up being quite large, and when it comes time to move them into production, they are too large or too slow to be viable.

But what can you do with this hefty model? You don't want to throw it away: it's still very accurate, and you spent a lot of time training it. Weight pruning is one option, but it can take as much time as retraining the entire model, and the resulting accuracy is often lower than the original's. Other model compression methods, such as weight quantization and weight sharing, can be applied as well, but their improvements are often much smaller than those of other methods, and they should be saved until the very end of optimization, if they are used at all. Are there any methods that require minimal effort from the engineer or researcher but consistently get good results?

Enter knowledge distillation! In essence, knowledge distillation uses the large model you already have to train a new, smaller model. Or, if accuracy is your only concern, it can be used to retrain a more accurate version of the existing model. These new models are often referred to as Student Models, and the original model is called the Parent or Teacher Model. Training a student model usually increases accuracy, rather than decreasing it as weight pruning does, and can often take less training time than the original. Depending on the student's chosen architecture, it can also be smaller and faster! For our example, though, we will just retrain a model of the same type to show the improvement in accuracy.

How it works

When training a classifier for a multi-class classification problem, you generally have a target output that is a one-hot vector: each class is a dimension of the vector, and only the dimension of the correct class is set to 1. For example, if you had a dataset of images where each image could be one of `[Cat, Dog, Car]`, the target for an image of a dog would look like:

      Cat   Dog   Car
    [ 0     1     0 ]
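Building such a target in code takes only a few lines. A minimal Python sketch, where `CLASSES` and `one_hot` are illustrative names rather than part of any library:

```python
# Illustrative class order taken from the example; not from any library.
CLASSES = ["Cat", "Dog", "Car"]


def one_hot(label: str, classes: list[str] = CLASSES) -> list[int]:
    """Return a one-hot target: 1 at the label's index, 0 everywhere else."""
    target = [0] * len(classes)
    target[classes.index(label)] = 1  # raises ValueError for an unknown label
    return target


print(one_hot("Dog"))  # [0, 1, 0]
```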

The same encoding extends to problems where an image can carry more than one label at a time (with a 1 in each correct dimension), but for this example we restrict each image to a single label. While you and I can easily tell whether something is a cat or a dog based on its ears or its size, a computer that hasn't been around animals as long might not. As a result, its output won't be exact 0s and 1s. Instead, it produces a probability that the image belongs to each class, and in our single-label case those probabilities sum to `1.0`. To continue the above example, the output vector for an image of a dog could be:

      Cat   Dog   Car
    [ 0.24  0.75  0.01 ]

From this, we can see that the model is much more likely to confuse a dog with a cat than with a car, suggesting that it has learned that dogs share more visual features with cats than with cars. The more unsure the previous model is of its prediction, the more information its output conveys about how the classes relate. That information is what drives knowledge distillation: the previous model's predictions are used as training targets alongside the true labels to train a new model. We believe this process allows a model to impart its knowledge of the similarities between classes to the next model. Training is usually done by giving the new model a loss function that incorporates both the previous predictions and the ground truth labels:

    Total Loss = Prediction Loss + λ Ground Truth Loss
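The formula can be sketched in plain Python. This is one reasonable instantiation, not necessarily the one used for the experiments below: it assumes both terms are cross-entropy losses, that the student emits raw scores (logits) that a softmax turns into probabilities, and that the teacher's stored predictions are already probabilities. The helper names `softmax`, `cross_entropy`, and `distillation_loss` are hypothetical.

```python
import math

# Assumption: both loss terms are cross-entropies; helper names are illustrative.


def softmax(logits: list[float]) -> list[float]:
    """Convert raw scores into probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(target: list[float], predicted: list[float], eps: float = 1e-12) -> float:
    """-sum(target_i * log(predicted_i)); eps guards against log(0)."""
    return -sum(t * math.log(p + eps) for t, p in zip(target, predicted))


def distillation_loss(student_logits: list[float], teacher_probs: list[float],
                      true_one_hot: list[float], lam: float) -> float:
    """Total Loss = Prediction Loss + lam * Ground Truth Loss."""
    student_probs = softmax(student_logits)
    prediction_loss = cross_entropy(teacher_probs, student_probs)   # match the teacher
    ground_truth_loss = cross_entropy(true_one_hot, student_probs)  # match the true label
    return prediction_loss + lam * ground_truth_loss


# Teacher output and label for the dog image from the example; student scores are made up.
teacher = [0.24, 0.75, 0.01]
label = [0.0, 1.0, 0.0]
student_logits = [1.0, 2.0, -3.0]
print(distillation_loss(student_logits, teacher, label, lam=0.5))
```

Because the Prediction Loss compares the student against the whole teacher vector, the student is also pushed to give the Cat class some probability, which is how the similarity information reaches the new model.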

λ is a parameter we introduce to control which part of the loss function matters more to the model being trained. It can be scheduled much like a model's learning rate: we start training mostly on the previous predictions to show the new model the similarities between classes, and then gradually incorporate more of the Ground Truth Loss to correct any of the previous model's mistakes.
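No particular schedule for λ is prescribed here, so the following is only a sketch under an assumed linear ramp: λ starts small, so the teacher's predictions dominate, and grows each epoch to give the ground truth more weight. `lambda_schedule`, `lam_start`, and `lam_end` are hypothetical names, and the default values are placeholders.

```python
def lambda_schedule(epoch: int, num_epochs: int,
                    lam_start: float = 0.1, lam_end: float = 1.0) -> float:
    """Linearly interpolate lambda from lam_start (first epoch) to lam_end (last epoch).

    Assumed schedule for illustration; any monotonically increasing ramp would fit the idea.
    """
    if num_epochs <= 1:
        return lam_end
    fraction = epoch / (num_epochs - 1)  # 0.0 at the first epoch, 1.0 at the last
    return lam_start + fraction * (lam_end - lam_start)


for epoch in range(10):
    lam = lambda_schedule(epoch, num_epochs=10)
    print(f"epoch {epoch}: lambda = {lam:.2f}")
```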

Figure 1: Improved training flow



Model Optimization

Improving Accuracy

To show you how it works, let's train a small example neural network on the CIFAR-10 dataset. This model consists of 2 convolutional blocks with ReLU activations and a dropout rate of 0.25. The code for the model can be found at this gist, and the model looks like this:

Figure 2: Model Architecture



I kept the model simple so that anyone can train it quickly (about 10 s per epoch on a laptop). Because it's so small, it won't be very accurate, but absolute accuracy doesn't matter here; we only use it to compare models.

To compare accuracies, we evaluate 10 models: five Baseline Models and five Student Models. The Baseline Models were trained for 20 epochs on the ground truth labels. A separate set of five Teacher Models was trained for 10 epochs; their predictions on the training set were then gathered and used to train the five Student Models for another 10 epochs. Each Baseline Model and each Teacher and Student pair therefore accounts for 20 training epochs and the same amount of training time. This should give a fair comparison between a traditionally trained model and one trained with knowledge distillation. The results are summarized in the table below:

    Model   Baseline   Students
    1       0.6833     0.714
    2       0.6971     0.712
    3       0.6975     0.699
    4       0.7124     0.715
    5       0.6939     0.717
    Avg     0.6968     0.7114
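The column averages can be recomputed with Python's standard `statistics` module; the two lists simply copy the five accuracies from each column of the table.

```python
from statistics import mean, stdev

# Accuracies copied from the table (one entry per trained model).
baseline = [0.6833, 0.6971, 0.6975, 0.7124, 0.6939]
students = [0.714, 0.712, 0.699, 0.715, 0.717]

for name, accs in [("Baseline", baseline), ("Students", students)]:
    # The sample standard deviation shows how much individual runs vary around the mean.
    print(f"{name}: mean={mean(accs):.4f}, stdev={stdev(accs):.4f}")

print(f"Average gain: {mean(students) - mean(baseline):.4f}")
```

Reporting the spread next to the mean helps judge whether a gain of this size stands out from run-to-run noise.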

Looking at the accuracy table, knowledge distillation improves accuracy on average without adding any extra training epochs, even with such a small model. When using larger models, like those found in Automatic Speech Recognition (ASR), you can see even larger boosts in accuracy, nearly on par with an ensemble of models. For an example of using knowledge distillation to improve the accuracy of large models, I highly recommend the original paper by Hinton et al., "Distilling the Knowledge in a Neural Network". In it, you can see how they used it to improve their ASR model, along with a more detailed explanation of classic knowledge distillation.
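In the paper by Hinton et al., the teacher's and the student's outputs are both softened with a temperature T before the soft-target loss is computed, which makes small probabilities (like the Car entry for a dog image) more visible to the student. A minimal sketch of that temperature-scaled softmax, using made-up logits for `[Cat, Dog, Car]`; the function name is illustrative:

```python
import math


def softmax_with_temperature(logits: list[float], temperature: float = 1.0) -> list[float]:
    """softmax(z / T): T = 1 is the ordinary softmax, larger T gives a softer distribution."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 4.0, -2.0]  # hypothetical teacher scores for [Cat, Dog, Car]
for t in (1.0, 4.0):
    probs = softmax_with_temperature(logits, t)
    print(t, [round(p, 3) for p in probs])
```

Raising T keeps Dog as the top class but spreads more probability onto Cat and Car. The paper also multiplies the soft-target term by T² so that its gradients keep a comparable scale when combined with the ground-truth term.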

Reducing Model Size

But what if accuracy isn't your only concern? In the introduction, I mentioned that the models developed during the experimentation phase usually aren't the ones that end up in production, often because of their size or complexity. Knowledge distillation lets you distill a large model into a much smaller student. For example, the authors of "A Gift from Knowledge Distillation: Fast Optimization, Network Minimization, and Transfer Learning" used knowledge distillation to train a smaller version of a large convolutional model on the CIFAR-10 dataset. Not only were they able to reduce the model's size by a factor of 2.5, but they also increased its test accuracy while using only half the number of training iterations!

Combining Models

Knowledge distillation isn't limited to transfer between networks of the same architecture, either; there are many examples of it being used to combine several models of different architectures into one. For example, the winners of the 2nd YouTube-8M Video Understanding Challenge on Kaggle used knowledge distillation in their solution.

The challenge had a strict size limit for the winning model: its uncompressed weight file had to be smaller than 1 GB. This meant participants had to consider each portion of their model carefully to make sure they weren't using any unnecessary space. In their first-place solution, the team "Next top GB model" used knowledge distillation to combine models with NetVLAD, Deep Bag of Frames, FV Net, and RNN architectures into one final model. You can find their full write-up in their 1st place solution summary. So even when your models span a variety of architectures, you can still improve overall accuracy with knowledge distillation.
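One common way to distill several teachers into a single student is to average their predicted class probabilities and use that average as the student's soft target. This is a generic sketch, not a description of the winning team's pipeline; the function name and the teacher outputs are hypothetical.

```python
from typing import Optional


def average_teacher_predictions(teacher_probs: list[list[float]],
                                weights: Optional[list[float]] = None) -> list[float]:
    """Average several teachers' per-class probabilities into one soft target.

    Without weights every teacher counts equally; weights are normalized by their sum.
    """
    if weights is None:
        weights = [1.0] * len(teacher_probs)
    total = sum(weights)
    num_classes = len(teacher_probs[0])
    return [
        sum(w * probs[c] for w, probs in zip(weights, teacher_probs)) / total
        for c in range(num_classes)
    ]


# Hypothetical outputs of three teachers with different architectures for one example.
teachers = [
    [0.70, 0.20, 0.10],
    [0.60, 0.30, 0.10],
    [0.80, 0.10, 0.10],
]
soft_target = average_teacher_predictions(teachers)
print([round(p, 2) for p in soft_target])
```

The averaged vector can then stand in for the single teacher's predictions in the Prediction Loss; weighting lets stronger teachers contribute more.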

Final Word

We've seen through examples, papers, and a real competition how powerful a tool knowledge distillation can be. On your next project, try splitting your training time: half for the original model and half for knowledge distillation. And the next time you need to train a small model for production, or you're struggling to eke out the last bit of accuracy from your chosen model, give knowledge distillation a second look. It really won't let you down.