Data scientists and developers can now easily perform incremental learning on Amazon SageMaker. Incremental learning is a machine learning (ML) technique for extending the knowledge of an existing model by training it further on new data. Starting today, both of the Amazon SageMaker built-in visual recognition algorithms – Image Classification and Object Detection – provide out-of-the-box support for incremental learning. You can now load an existing Amazon SageMaker visual recognition model using the AWS Management Console or the Amazon SageMaker Python SDK APIs before starting model training on new data.
Overview
Incremental learning is the technique of continuously extending the knowledge of an existing machine learning model by training it further on new data. At the beginning of a training run, you load the model weights from a prior training run instead of initializing them randomly, and then continue training the model on new data. In this way you preserve the knowledge the model gained from prior training runs and extend it further. This is useful when you don’t have access to all of the training data at once and your data arrives in batches over time. You can also use this technique to save time and compute resources when retraining your model on new training data.
In this blog post we’ll also demonstrate how to use the Amazon SageMaker incremental learning features to perform transfer learning. For this demonstration we’ll take an existing model off the shelf: we’ll choose an image classification model from a model zoo and use it as a starting point for training a model that performs a new classification task. Transfer learning lets you build new models on top of state-of-the-art reference implementations for specific machine learning tasks. It is also useful when you don’t have enough data to train a deep, complex network from scratch.
Now let’s dive into the examples.
Incrementally train visual recognition models using Amazon SageMaker built-in algorithms
We have provided sample notebooks for both of the Amazon SageMaker built-in visual recognition algorithms – Image Classification and Object Detection – that now support incremental learning. Following are code snippets from the Image Classification notebook. If you are training an Amazon SageMaker Image Classification model for the first time, the notebook has step-by-step instructions. In this example we assume you already have an Image Classification model that was previously trained on Amazon SageMaker.
Step 1: Define an input channel for consuming the existing Amazon SageMaker Image Classification model.
An Amazon SageMaker channel is a named input data source that training algorithms can consume. This input channel must be named “model” and it specifies the Amazon S3 URI of the existing model. Note that the existing model artifact is a single gzip-compressed tar archive (.tar.gz suffix) created by Amazon SageMaker training.
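Here is a minimal sketch of what this looks like with the Amazon SageMaker Python SDK. The bucket name and prefix are placeholders for your own model artifact location:

```python
from sagemaker.session import s3_input

# Placeholder S3 URI: point this at the model.tar.gz produced by your
# prior Amazon SageMaker Image Classification training job.
model_uri = 's3://<your-bucket>/<prior-training-job>/output/model.tar.gz'

# The channel must be named "model" so the algorithm initializes from these
# weights instead of starting from random values.
model_channel = s3_input(model_uri,
                         distribution='FullyReplicated',
                         content_type='application/x-sagemaker-model',
                         s3_data_type='S3Prefix',
                         input_mode='File')
```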
Step 2: Now continue training on a new batch of training data. The hyperparameters that define the network, such as num_layers, image_shape, and num_classes, must be the same as those used for training the existing model. Because the algorithm starts with an existing, pre-trained model, accuracy is higher right from the first epoch, leading to faster convergence.
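The following sketch shows one way to wire the “model” channel from Step 1 into a new training job. The bucket paths and hyperparameter values below are illustrative placeholders, not the notebook’s exact settings:

```python
import sagemaker
from sagemaker import get_execution_role
from sagemaker.amazon.amazon_estimator import get_image_uri
from sagemaker.estimator import Estimator
from sagemaker.session import s3_input

session = sagemaker.Session()

# Built-in Image Classification algorithm container for the current Region.
training_image = get_image_uri(session.boto_region_name, 'image-classification')

ic_estimator = Estimator(training_image,
                         get_execution_role(),
                         train_instance_count=1,
                         train_instance_type='ml.p3.2xlarge',
                         output_path='s3://<your-bucket>/output',
                         sagemaker_session=session)

# These network-defining hyperparameters must match the existing model.
# The values below are examples only.
ic_estimator.set_hyperparameters(num_layers=18,
                                 image_shape='3,224,224',
                                 num_classes=257,
                                 num_training_samples=15420,
                                 epochs=2,
                                 learning_rate=0.001)

# Channels for the new batch of data, plus the "model" channel from Step 1.
train_channel = s3_input('s3://<your-bucket>/train/',
                         content_type='application/x-recordio')
validation_channel = s3_input('s3://<your-bucket>/validation/',
                              content_type='application/x-recordio')

ic_estimator.fit({'train': train_channel,
                  'validation': validation_channel,
                  'model': model_channel})
```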
You can repeat these steps as many times as you need to train your model further on new data.
Use a pre-trained Caffe model from ONNX model zoo to perform your image classification task
We’ll now show you an example of how to pick a model off the shelf – in this case, a Caffe BVLC GoogleNet model that was trained on the ImageNet dataset and is available in the ONNX Model Zoo. We’ll use this model as a starting point and fine-tune it for a new image classification task on the Caltech 101 dataset using Amazon SageMaker. We’re using the same model training script shown in the MXNet/Gluon tutorial for transfer learning.
We’ll use the Amazon SageMaker MXNet framework container to train the model. Also note that this example uses the Amazon SageMaker Python SDK, similar to our existing Gluon notebooks.
Step 1: Download the pre-trained GoogleNet model from the ONNX model zoo and upload the model.onnx file to Amazon S3.
The ONNX model zoo hosts pre-trained models in Amazon S3 buckets in the us-east-1 AWS Region. You can use the Amazon S3 URI of the pre-trained model as is. However, if you are running Amazon SageMaker training in a different AWS Region (such as us-west-2), here is sample code for copying the file across Regions.
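A minimal sketch using boto3; the source bucket and key and the destination bucket below are placeholders, so substitute the actual ONNX model zoo location and your own bucket:

```python
import boto3

# Placeholder source: the ONNX model zoo's S3 location in us-east-1.
copy_source = {'Bucket': '<onnx-model-zoo-bucket>',
               'Key': '<path/to/bvlc_googlenet/model.onnx>'}

# Placeholder destination: your own bucket in us-west-2.
s3 = boto3.resource('s3', region_name='us-west-2')
s3.Bucket('<your-bucket-in-us-west-2>').copy(copy_source, 'pretrained/model.onnx')
```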
Step 2: Define Amazon SageMaker channels for the input data – one for the Caltech 101 training dataset and another for the pre-trained GoogleNet model.
In this example we define a ‘training’ channel for the Caltech 101 training dataset, and a ‘pretrained’ channel for the pre-trained GoogleNet model (from Step 1).
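Here is a sketch of the channel definitions, again with placeholder bucket paths:

```python
from sagemaker.session import s3_input

# 'training' channel: the Caltech 101 training images.
train_channel = s3_input('s3://<your-bucket>/caltech101/train/',
                         input_mode='File')

# 'pretrained' channel: the S3 prefix containing model.onnx from Step 1.
pretrained_channel = s3_input('s3://<your-bucket>/pretrained/',
                              input_mode='File')

inputs = {'training': train_channel, 'pretrained': pretrained_channel}
```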
As you can see, we define the input mode as ‘File’ at each channel level. File mode fetches the pre-trained model from Amazon S3 to local storage attached to the Amazon SageMaker training instance before model training starts.
Now, before we show you the code for starting Amazon SageMaker training using our pre-built MXNet container, we’ll first show you the small, one-line code changes to the model training script from the Gluon tutorial for transfer learning that give it easy access to your pre-trained GoogleNet model.
Step 3: Easily access the channel information inside the MXNet container using environment variables.
You can use the default environment variables of the MXNet container, which Amazon SageMaker automatically initializes with all the information about the input channels you defined in Step 2.
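For example, a sketch assuming the script-mode MXNet container, which surfaces each channel’s local path in an SM_CHANNEL_<NAME> environment variable:

```python
import os

# Local directories where SageMaker has downloaded each channel's data.
training_dir = os.environ['SM_CHANNEL_TRAINING']
pretrained_dir = os.environ['SM_CHANNEL_PRETRAINED']
```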
Now you are ready to call the train function in the model training script, passing it the Caltech 101 training dataset and the pre-trained GoogleNet model.
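A sketch of that entry point; the exact signature of train is an assumption here and should match your script:

```python
if __name__ == '__main__':
    # Hand the channel directories to the training function defined
    # in this script (signature assumed for illustration).
    train(training_dir=training_dir, pretrained_dir=pretrained_dir)
```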
You can save this updated script as transfer_learning_example.py.
Following is a short code snippet from the train function for illustration purposes. As you can see, the function loads the pre-trained GoogleNet model before fine-tuning it on the Caltech 101 training dataset.
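Here is a sketch of what that loading step can look like with MXNet’s ONNX import; the file name and the elided fine-tuning details are assumptions:

```python
import os
import mxnet as mx

def train(training_dir, pretrained_dir):
    # Import the pre-trained GoogleNet graph and weights from the ONNX file
    # delivered through the 'pretrained' channel.
    sym, arg_params, aux_params = mx.contrib.onnx.import_model(
        os.path.join(pretrained_dir, 'model.onnx'))

    # ... replace the final classification layer for the Caltech 101 classes
    # and continue training on the images under training_dir ...
```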
Step 4: Train the model on Amazon SageMaker using a pre-built MXNet container.
You are now ready to run the training script from Step 3 using a pre-built Amazon SageMaker MXNet container. We recommend using a GPU instance for faster training. In this example, we use an ml.p3.2xlarge instance.
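A sketch using the SDK’s MXNet estimator; the framework_version and hyperparameters shown are assumptions:

```python
from sagemaker import get_execution_role
from sagemaker.mxnet import MXNet

estimator = MXNet(entry_point='transfer_learning_example.py',
                  role=get_execution_role(),
                  train_instance_count=1,
                  train_instance_type='ml.p3.2xlarge',
                  framework_version='1.3.0',
                  py_version='py3',
                  hyperparameters={'epochs': 2})

# 'inputs' is the channel dictionary defined in Step 2.
estimator.fit(inputs)
```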
Step 5: Observe the improvement in training accuracy from the training logs.
Our training script prints the accuracy of the untrained network on the new dataset, and then the accuracy after fine-tuning on that dataset.
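A sketch of the reporting pattern in the training script; evaluate and train_one_epoch are hypothetical helper names, not functions from the tutorial:

```python
# Report accuracy before fine-tuning, then after each epoch.
val_acc = evaluate(net, val_data)              # hypothetical helper
print('Untrained network accuracy: %.4f' % val_acc)

for epoch in range(epochs):
    train_one_epoch(net, train_data)           # hypothetical helper
    val_acc = evaluate(net, val_data)
    print('Epoch %d validation accuracy: %.4f' % (epoch, val_acc))
```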
As you can see, we were able to improve our accuracy on the Caltech 101 dataset substantially with just a few minutes of fine-tuning on a GPU!
Get started with more examples and developer support
In this blog post we showed you examples of how to easily perform incremental learning and transfer learning using input channels on Amazon SageMaker. You can refer to our developer guide for more developer resources, or post your questions on our developer forum. Happy modeling!
About the authors
Gurumurthy Swaminathan is a Senior Applied Scientist in the Amazon AI Platforms group, where he works on building computer vision algorithms for SageMaker. His current research areas include neural network compression and computer vision algorithms.
Jeffrey Geevarghese is a Senior Engineer in Amazon AI where he’s passionate about building scalable infrastructure for deep learning. Prior to this, he worked on machine learning algorithms and platforms and was part of the launch teams for both Amazon SageMaker and Amazon Machine Learning.
Sumit Thakur is a Senior Product Manager for AWS Machine Learning Platforms where he loves working on products that make it easy for customers to get started with machine learning on cloud. He is product manager for Amazon SageMaker and AWS Deep Learning AMI. In his spare time, he likes connecting with nature and watching sci-fi TV series.