SVM MNIST handwritten digit classification
SVM MNIST digit classification in python using scikit-learn
This project presents the well-known problem of MNIST handwritten digit classification. For the purpose of this tutorial I will use the Support Vector Machine (SVM) algorithm with raw pixel features. The solution is written in Python using scikit-learn, an easy-to-use machine learning library.
The goal of this project is not to achieve state-of-the-art performance, but rather to teach you how to train an SVM classifier on image data using the SVM implementation from sklearn. Although the solution isn't optimized for high accuracy, the results are quite good (see the table below).
If you want to hit top performance, these two resources will show you the current state-of-the-art solutions:
The table below compares some results with other models:
| Method | Accuracy | Comments |
|---|---|---|
| Simple one-layer neural network | 0.926 | |
| Simple 2-layer convolutional network | 0.981 | |
| SVM RBF | 0.9852 | C=5, gamma=0.05 |
| Linear SVM + Nystroem kernel approximation | | |
| Linear SVM + Fourier kernel approximation | | |
In this tutorial I use two approaches to SVM learning. The first uses a classical SVM with an RBF kernel. Its drawback is rather long training on big datasets, although with good parameters the accuracy is high. The second uses a linear SVM, which allows training in O(n) time. To achieve high accuracy with it we use a trick: we approximate the RBF kernel with an explicit embedding into a high-dimensional space, where ordinary inner products approximate the RBF kernel values, so a linear SVM trained there behaves much like an RBF SVM. The theory behind this is quite complicated, but sklearn has ready-to-use classes for kernel approximation. We will use:
- Nystroem kernel approximation (sklearn.kernel_approximation.Nystroem)
- Fourier kernel approximation (random Fourier features, sklearn.kernel_approximation.RBFSampler)
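The "Fourier kernel approximation" idea can be made concrete with a minimal from-scratch sketch of random Fourier features (Rahimi and Recht). It assumes the RBF kernel in the form k(x, y) = exp(-gamma * ||x - y||^2) that scikit-learn uses. This is not the project's code, and the helper names `rbf_kernel_exact` and `random_fourier_features` are hypothetical. The sketch checks that inner products of the explicit features approach the exact kernel as the number of components grows. sklearn's `RBFSampler` implements the same construction. `Nystroem` instead builds its map from a subset of the training samples.

```python
import numpy as np


def rbf_kernel_exact(X, Y, gamma):
    """Exact RBF kernel matrix k(x, y) = exp(-gamma * ||x - y||^2)."""
    sq_dists = (
        np.sum(X ** 2, axis=1)[:, None]
        + np.sum(Y ** 2, axis=1)[None, :]
        - 2.0 * X @ Y.T
    )
    # clip tiny negative values caused by floating-point round-off
    return np.exp(-gamma * np.maximum(sq_dists, 0.0))


def random_fourier_features(X, gamma, n_components, seed=0):
    """Explicit map z(x) such that z(x) . z(y) approximates exp(-gamma * ||x - y||^2).

    Use the same seed for training and test data so both share one feature map.
    """
    rng = np.random.default_rng(seed)
    n_features = X.shape[1]
    # frequencies sampled from the Fourier transform of the RBF kernel: N(0, 2 * gamma * I)
    W = rng.normal(0.0, np.sqrt(2.0 * gamma), size=(n_features, n_components))
    # random phase offsets in [0, 2*pi)
    b = rng.uniform(0.0, 2.0 * np.pi, size=n_components)
    return np.sqrt(2.0 / n_components) * np.cos(X @ W + b)


rng = np.random.default_rng(1)
X = rng.normal(size=(50, 10))  # 50 toy points with 10 features
gamma = 0.05
K = rbf_kernel_exact(X, X, gamma)

errors = {}
for n_components in (100, 1000, 10000):
    Z = random_fourier_features(X, gamma, n_components)
    errors[n_components] = np.abs(Z @ Z.T - K).max()  # worst-case entry error
    print(n_components, round(errors[n_components], 3))
```

The approximation error shrinks roughly like 1/sqrt(n_components). This is why the embedding dimension is a trade-off between accuracy and the cost of the linear SVM that follows.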
The code was tested with python 2.7 and python 3.5.
How the project is organised
The project consists of three files:
- mnist_helpers.py - contains some visualization functions: MNIST digit visualization and the confusion matrix
- svm_mnist_classification.py - script for classification with an RBF kernel SVM
- svm_mnist_embedings.py - script for a linear SVM with embeddings
SVM with RBF kernel
The svm_mnist_classification.py script downloads the MNIST database and visualizes some random digits. Next, it standardizes the data (mean=0, std=1) and launches a grid search with cross-validation to find the best parameters.
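The standardization step can be sketched as follows. This is a minimal example that assumes scikit-learn's `StandardScaler`. To keep it runnable offline, it uses the small 8x8 digits dataset bundled with scikit-learn as a stand-in for the MNIST download. The 15% test split and the random seed are also assumptions, not the script's settings.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Stand-in data: scikit-learn's bundled 8x8 digits, so no download is needed.
# For the real 28x28 MNIST you could use
# sklearn.datasets.fetch_openml("mnist_784", version=1, as_frame=False).
X, y = load_digits(return_X_y=True)  # X: (1797, 64) raw pixel values

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.15, random_state=42, stratify=y
)

# Fit the scaler on the training split only and reuse it for the test split,
# so no test-set statistics leak into training.
scaler = StandardScaler()
X_train_std = scaler.fit_transform(X_train)
X_test_std = scaler.transform(X_test)

# Pixels that are constant in the training set (e.g. always-zero border pixels)
# have zero variance; StandardScaler keeps their scale at 1, so they become all 0.
varying = X_train.std(axis=0) > 0
print("train/test shapes:", X_train_std.shape, X_test_std.shape)
print("max |mean|:", np.abs(X_train_std.mean(axis=0)).max())
print("std of varying pixels is 1:", np.allclose(X_train_std[:, varying].std(axis=0), 1.0))
```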
- MNIST SVM kernel RBF Param search C=[0.1,0.5,1,5], gamma=[0.01,0.05,0.1,0.5].
Grid search was done over the parameters C and gamma, where C=[0.1,0.5,1,5] and gamma=[0.01,0.05,0.1,0.5]. So far I have examined only 4x4 different parameter pairs with 3-fold cross-validation (4x4x3=48 models); this procedure takes 3687.2 min :) (2 days, 13:56:42.531223 exactly) on a single CPU core.
The parameter space was generated with numpy logspace and an outer product:

    import numpy as np

    # C values: outer product of logspace(-1, 0, 2) = [0.1, 1] with [1, 5]
    C_range = np.outer(np.logspace(-1, 0, 2), np.array([1, 5]))
    # flatten matrix, change to 1D numpy array: [0.1, 0.5, 1, 5]
    C_range = C_range.flatten()

    # gamma values: outer product of [0.01, 0.1] with [1, 5], flattened: [0.01, 0.05, 0.1, 0.5]
    gamma_range = np.outer(np.logspace(-2, -1, 2), np.array([1, 5]))
    gamma_range = gamma_range.flatten()

Of course, you can broaden the range of parameters, but this will increase the computation time.
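The grid search itself can be sketched with scikit-learn's `GridSearchCV` and the same 4x4 grid with 3 folds. This is a minimal example, not the project's script. It assumes the bundled 8x8 digits as stand-in data so it finishes in seconds rather than days. It also puts the scaler inside a pipeline, so each cross-validation fold is standardized with statistics from its own training part. That is a design choice that may differ from the original script. The best parameters it finds on this small dataset need not match the MNIST ones.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

# Same parameter grid: C = [0.1, 0.5, 1, 5], gamma = [0.01, 0.05, 0.1, 0.5]
C_range = np.outer(np.logspace(-1, 0, 2), np.array([1, 5])).flatten()
gamma_range = np.outer(np.logspace(-2, -1, 2), np.array([1, 5])).flatten()

# Stand-in data (8x8 digits) so the sketch runs quickly
X, y = load_digits(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.15, random_state=42, stratify=y
)

# The scaler inside the pipeline is re-fitted on every CV training fold
pipe = make_pipeline(StandardScaler(), SVC(kernel="rbf"))
param_grid = {"svc__C": C_range, "svc__gamma": gamma_range}

# 4 x 4 parameter pairs x 3 folds = 48 fits, plus one refit on the whole training set.
# Passing n_jobs=-1 would spread the fits over all CPU cores.
grid = GridSearchCV(pipe, param_grid, cv=3)
grid.fit(X_train, y_train)

print("Best params:", grid.best_params_)
print("CV accuracy: %.4f" % grid.best_score_)
print("Test accuracy: %.4f" % grid.score(X_test, y_test))
```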
Grid search is a very time-consuming process, so you can use my best parameters (found in the grid C=[0.1,0.5,1,5], gamma=[0.01,0.05,0.1,0.5]):
- C = 5
- gamma = 0.05
- accuracy = 0.9852
Confusion matrix:

    [[1014    0    2    0    0    2    2    0    1    3]
     [   0 1177    2    1    1    0    1    0    2    1]
     [   2    2 1037    2    0    0    0    2    5    1]
     [   0    0    3 1035    0    5    0    6    6    2]
     [   0    0    1    0  957    0    1    2    0    3]
     [   1    1    0    4    1  947    4    0    5    1]
     [   2    0    1    0    2    0 1076    0    4    0]
     [   1    1    8    1    1    0    0 1110    2    4]
     [   0    4    2    4    1    6    0    1 1018    1]
     [   3    1    0    7    5    2    0    4    9  974]]
    Accuracy=0.985238095238
- MNIST SVM kernel RBF Param search C=[0.1,0.5,1,5, 10, 50], gamma=[0.001, 0.005, 0.01,0.05,0.1,0.5].
This much broader search over 6x8 parameter pairs with 3-fold cross-validation gives 6x8x3=144 models; this procedure takes 13024.3 min (9 days, 1:33:58.999782 exactly) on a single CPU core. The best parameters did not change:
- C = 5
- gamma = 0.05
- accuracy = 0.9852
Linear SVM with different embeddings
Linear SVMs (SVMs with a linear kernel) have the advantage that there are many O(n) training algorithms for them. They are really fast in comparison with nonlinear (kernel) SVMs, whose solvers mostly scale at least as O(n^2). This technique is really useful if you want to train on big data.
Linear SVM algorithm examples (papers and software):
- Stochastic gradient descent
- Averaged Stochastic gradient descent
- Stochastic Gradient Descent with Barzilai-Borwein update step for SVM
- Primal SVM by Olivier Chapelle - there is also a Primal SVM implementation in Python
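To see why such methods scale linearly, here is a minimal sketch of the Pegasos update rule for a binary problem with labels in {-1, +1}. Pegasos is stochastic sub-gradient descent on the primal SVM objective (Shalev-Shwartz et al.). The sketch is a from-scratch NumPy illustration, not one of the implementations listed above. The function name `pegasos_train`, the regularization value and the epoch count are hypothetical. Each epoch touches every sample once, so one epoch costs O(n·d). For the 10 MNIST classes you would train one such model per digit (one-vs-rest).

```python
import numpy as np


def pegasos_train(X, y, lam=0.01, n_epochs=10, seed=0):
    """Binary linear SVM trained with Pegasos-style stochastic sub-gradient steps.

    Minimizes lam/2 * ||w||^2 + mean(max(0, 1 - y * (X @ w))) for labels y in {-1, +1}.
    No bias term is learned (append a constant feature to X if you need one),
    and the optional projection step of Pegasos is omitted.
    """
    rng = np.random.default_rng(seed)
    n_samples, n_features = X.shape
    w = np.zeros(n_features)
    t = 0
    for _ in range(n_epochs):
        # one epoch = one pass over the samples in random order: O(n * d) work
        for i in rng.permutation(n_samples):
            t += 1
            eta = 1.0 / (lam * t)        # decreasing Pegasos step size
            margin = y[i] * X[i].dot(w)  # evaluated at the current w
            w *= 1.0 - eta * lam         # sub-gradient step for the regularizer
            if margin < 1.0:             # hinge loss active: move towards y[i] * x_i
                w += eta * y[i] * X[i]
    return w


# Toy usage: two well-separated Gaussian blobs, symmetric around the origin
rng = np.random.default_rng(1)
X = np.vstack([rng.normal(2.0, 1.0, (100, 2)), rng.normal(-2.0, 1.0, (100, 2))])
y = np.hstack([np.ones(100), -np.ones(100)])
w = pegasos_train(X, y)
accuracy = np.mean(np.sign(X @ w) == y)
print("weights:", w, "training accuracy:", accuracy)
```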
Unfortunately, a linear SVM on raw pixels isn't powerful enough to reach accuracy comparable to the RBF SVM.
Training an SVM with an RBF kernel can be time-consuming. To get a more expressive model without that cost, we approximate the nonlinear kernel: we explicitly map the vectors into a higher-dimensional space and train a fast linear SVM in this new space. This works extremely well!
The script svm_mnist_embedings.py presents an accuracy summary and training times for the full RBF kernel SVM, a linear SVC, and a linear SVC with two kernel approximations, Nystroem and Fourier.
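The kind of comparison that script reports can be sketched with scikit-learn's `Nystroem` and `RBFSampler` (the random Fourier feature approximation) placed in front of `LinearSVC`. This is not the script itself. It assumes the bundled 8x8 digits as stand-in data, and the values of gamma, C and n_components are illustrative rather than tuned. On a dataset this small the timings say little. The advantage of the linear models is expected to appear as the number of training samples grows.

```python
import time

from sklearn.datasets import load_digits
from sklearn.kernel_approximation import Nystroem, RBFSampler
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC, LinearSVC

# Stand-in data: bundled 8x8 digits instead of the full MNIST download
X, y = load_digits(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.15, random_state=42, stratify=y
)

gamma = 0.01        # illustrative value for 64 standardized pixels, not tuned
n_components = 500  # size of the explicit feature map (illustrative)

models = {
    "RBF SVC": make_pipeline(StandardScaler(), SVC(kernel="rbf", C=5, gamma=gamma)),
    "Linear SVC": make_pipeline(StandardScaler(), LinearSVC(C=1.0, dual=False)),
    "Nystroem + Linear SVC": make_pipeline(
        StandardScaler(),
        Nystroem(kernel="rbf", gamma=gamma, n_components=n_components, random_state=0),
        LinearSVC(C=1.0, dual=False),
    ),
    "Fourier + Linear SVC": make_pipeline(
        StandardScaler(),
        RBFSampler(gamma=gamma, n_components=n_components, random_state=0),
        LinearSVC(C=1.0, dual=False),
    ),
}

results = {}
for name, model in models.items():
    start = time.perf_counter()
    model.fit(X_train, y_train)
    train_time = time.perf_counter() - start
    accuracy = model.score(X_test, y_test)
    results[name] = (accuracy, train_time)
    print("%-22s accuracy=%.4f  training time=%.3fs" % (name, accuracy, train_time))
```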
Possible further improvements
- Augmenting the training set with artificial samples
- Using randomized parameter search
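Randomized parameter search can be sketched with scikit-learn's `RandomizedSearchCV`. It samples a fixed number of (C, gamma) pairs instead of evaluating every grid point, so the budget is set directly by `n_iter`. The log-uniform distributions span the same ranges as the broad grid search (C from 0.1 to 50, gamma from 0.001 to 0.5). The bundled 8x8 digits stand-in, `n_iter=10` and the seed are assumptions. `scipy.stats.loguniform` requires SciPy 1.4 or newer.

```python
from scipy.stats import loguniform
from sklearn.datasets import load_digits
from sklearn.model_selection import RandomizedSearchCV
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

# Stand-in data: bundled 8x8 digits
X, y = load_digits(return_X_y=True)

pipe = make_pipeline(StandardScaler(), SVC(kernel="rbf"))

# Sample C and gamma log-uniformly instead of fixing a lattice of values
param_distributions = {
    "svc__C": loguniform(0.1, 50),
    "svc__gamma": loguniform(0.001, 0.5),
}

# n_iter sets the budget directly: 10 sampled pairs x 3 folds = 30 fits
search = RandomizedSearchCV(
    pipe, param_distributions, n_iter=10, cv=3, random_state=0
)
search.fit(X, y)

print("Best params:", search.best_params_)
print("CV accuracy: %.4f" % search.best_score_)
```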
Useful SVM MNIST learning materials
- MNIST handwritten digit recognition - the author compares the accuracy of a few machine learning classification algorithms (Random Forest, Stochastic Gradient Descent, Support Vector Machine, Nearest Neighbors)
- Digit Recognition using OpenCV, sklearn and Python - this blog post shows how to use HOG features with a multiclass linear SVM.
- Grid search for RBF SVM parameters
- Fast and Accurate Digit Classification - technical report; there is also a download page with a custom LibLinear intersection kernel
- Random features for large-scale kernel machines - Rahimi, A. and Recht, B. - Advances in Neural Information Processing Systems, 2007
- Efficient additive kernels via explicit feature maps - Vedaldi, A. and Zisserman, A. - Computer Vision and Pattern Recognition, 2010
- Generalized RBF feature maps for efficient detection - Vempati, S., Vedaldi, A., Zisserman, A. and Jawahar, C. V. - 2010
[2017-04-04 20:44] v0.3.2 : Update MNIST SVM RBF param space
[2017-03-25 7:04] v0.3.1 : Add section for linear classifiers (LibLinear, SGD, Pegasos)
[2017-03-23 11:52] v0.2.10 : Fix grid classifier, add matplotlib heatmap for SVM param space
[2017-03-22 21:48] v0.2.9 : Add image width
[2017-03-22 21:44] v0.2.8 : Add image class
[2017-03-22 21:41] v0.2.7 : Add mnist digit image
[2017-03-18 22:26] v0.2.6 : Update readme, fix some issues in gridsearch
[2017-03-13 14:06] v0.2.5 : Fix links to github and plon
[2017-03-13 13:45] v0.2.4 : Update project to scikit-learn 0.18, change the readme add more information about project
Add links to github