
Optimize Classifier Fit Using Bayesian Optimization

This example shows how to optimize an SVM classifier by using the fitcsvm function and the OptimizeHyperparameters name-value argument.

Generate Data

The classification works on locations of points from a Gaussian mixture model. Hastie, Tibshirani, and Friedman (2009) describe the model on page 17 of The Elements of Statistical Learning. The model begins by generating 10 base points for a "green" class, distributed as 2-D independent normals with mean (1,0) and unit variance. It also generates 10 base points for a "red" class, distributed as 2-D independent normals with mean (0,1) and unit variance. For each class (green and red), generate 100 random points as follows:

  1. Choose a base point m of the appropriate color uniformly at random.

  2. Generate an independent random point with a 2-D normal distribution with mean m and variance I/5, where I is the 2-by-2 identity matrix. In this example, use a variance of I/50 to show the advantage of optimization more clearly.
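The two-step procedure above can be sketched in base MATLAB without mvnrnd. Because the covariance s2*I is isotropic, a draw from a normal distribution with mean m and covariance s2*I equals m + sqrt(s2)*randn(1,2). The anonymous function sampleClass is a hypothetical helper for illustration only; it consumes random numbers differently from the mvnrnd loop used later, so its points do not match the ones in this example.

```matlab
% Hypothetical helper: draw n points from the mixture described above.
%   basePts : k-by-2 matrix of base points (one row per base point)
%   n       : number of points to generate
%   s2      : variance of each coordinate around the chosen base point
sampleClass = @(basePts,n,s2) ...
    basePts(randi(size(basePts,1),n,1),:) + sqrt(s2)*randn(n,2);

% Base points: 10 per class, normal with mean (1,0) or (0,1) and unit variance
grnBase = [1 0] + randn(10,2);
redBase = [0 1] + randn(10,2);

% Step 1 picks a base point uniformly at random (randi); step 2 adds
% normal noise with covariance s2*I. s2 = 1/50 matches this example.
grnSample = sampleClass(grnBase,100,1/50);
redSample = sampleClass(redBase,100,1/50);
```

Setting s2 = 0 makes every sample coincide with one of the base points, which is a quick way to check that step 1 selects base points correctly.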

Generate the 10 base points for each class.

rng('default') % For reproducibility
grnpop = mvnrnd([1,0],eye(2),10);
redpop = mvnrnd([0,1],eye(2),10);

View the base points.

plot(grnpop(:,1),grnpop(:,2),'go')
hold on
plot(redpop(:,1),redpop(:,2),'ro')
hold off

Figure: the green and red base points, plotted as circles.

Since some red base points are close to green base points, it can be difficult to classify the data points based on location alone.

Generate the 100 data points of each class.

redpts = zeros(100,2);
grnpts = redpts;
for i = 1:100
    grnpts(i,:) = mvnrnd(grnpop(randi(10),:),eye(2)*0.02);
    redpts(i,:) = mvnrnd(redpop(randi(10),:),eye(2)*0.02);
end

View the data points.

figure
plot(grnpts(:,1),grnpts(:,2),'go')
hold on
plot(redpts(:,1),redpts(:,2),'ro')
hold off

Figure: the 100 green and 100 red data points, plotted as circles.

Prepare Data for Classification

Put the data into one matrix, and make a vector grp that labels the class of each point. 1 indicates the green class, and –1 indicates the red class.

cdata = [grnpts;redpts];
grp = ones(200,1);
grp(101:200) = -1;

Prepare Cross-Validation

Set up a partition for cross-validation.

c = cvpartition(200,'KFold',10);

This step is optional. If you specify a partition for the optimization, then you can compute an actual cross-validation loss for the returned model.
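Conceptually, a 10-fold partition of 200 observations splits the rows into 10 disjoint test sets of 20 observations each, and each fold trains on the remaining 180. The sketch below builds such an assignment with randperm. It only illustrates the idea and is not a reimplementation of cvpartition, which uses its own random assignment; the variable names are hypothetical.

```matlab
n = 200;   % number of observations
k = 10;    % number of folds

% Shuffle the row indices and deal them into k folds of equal size
foldId = zeros(n,1);
foldId(randperm(n)) = mod(0:n-1,k)' + 1;

% Logical masks for fold 3, analogous to test(c,3) and training(c,3)
testMask  = (foldId == 3);
trainMask = ~testMask;
```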

Optimize Fit

To find a good fit, meaning one with optimal hyperparameters that minimize the cross-validation loss, use Bayesian optimization. Specify a list of hyperparameters to optimize by using the OptimizeHyperparameters name-value argument, and specify optimization options by using the HyperparameterOptimizationOptions name-value argument.

Specify 'OptimizeHyperparameters' as 'auto'. The 'auto' option includes a typical set of hyperparameters to optimize. fitcsvm finds optimal values of BoxConstraint, KernelScale, and Standardize. Set the hyperparameter optimization options to use the cross-validation partition c and to choose the 'expected-improvement-plus' acquisition function for reproducibility. The default acquisition function depends on run time and, therefore, can give varying results.
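An acquisition function scores candidate hyperparameters by using the Gaussian process model. As a sketch, the textbook expected-improvement criterion for minimization, given a posterior mean mu and standard deviation sigma at a candidate point and a reference value fBest, is EI = (fBest - mu)*Phi(z) + sigma*phi(z) with z = (fBest - mu)/sigma, where Phi and phi are the standard normal CDF and PDF. The 'plus' variant chosen in this example additionally changes its behavior when it detects that it is over-exploiting one area; that rule is omitted here, and bayesopt's exact definition may differ in details such as the reference value. The names expImprove, stdNormCdf, and stdNormPdf are hypothetical; erfc and exp are used so that no toolbox is required.

```matlab
% Standard normal PDF and CDF in base MATLAB
stdNormPdf = @(z) exp(-z.^2/2)/sqrt(2*pi);
stdNormCdf = @(z) 0.5*erfc(-z/sqrt(2));

% Expected improvement (minimization), evaluated elementwise
expImprove = @(mu,sigma,fBest) ...
    (fBest - mu).*stdNormCdf((fBest - mu)./sigma) + ...
    sigma.*stdNormPdf((fBest - mu)./sigma);

% Three hypothetical candidates with the same predicted loss of 0.08,
% compared against fBest = 0.07: more uncertainty (larger sigma)
% means a larger chance of beating fBest, so a larger EI
mu    = [0.08 0.08 0.08];
sigma = [0.005 0.01 0.05];
ei    = expImprove(mu,sigma,0.07)
```

This trade-off is why the optimization keeps sampling uncertain regions (for example, very small or very large KernelScale values) even after it has found a low loss.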

opts = struct('CVPartition',c,'AcquisitionFunctionName', ...
    'expected-improvement-plus');
Mdl = fitcsvm(cdata,grp,'KernelFunction','rbf', ...
    'OptimizeHyperparameters','auto','HyperparameterOptimizationOptions',opts)
|====================================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   | BoxConstraint|  KernelScale |  Standardize |
|      | result |             | runtime     | (observed)  | (estim.)    |              |              |              |
|====================================================================================================================|
|    1 | Best   |       0.195 |     0.21831 |       0.195 |       0.195 |       193.54 |     0.069073 |        false |
|    2 | Accept |       0.345 |     0.10288 |       0.195 |     0.20398 |       43.991 |       277.86 |        false |
|    3 | Accept |       0.365 |    0.085025 |       0.195 |     0.20784 |    0.0056595 |     0.042141 |        false |
|    4 | Accept |        0.61 |     0.17863 |       0.195 |     0.31714 |       49.333 |    0.0010514 |         true |
|    5 | Best   |         0.1 |     0.30419 |         0.1 |     0.10005 |       996.27 |       1.3081 |        false |
|    6 | Accept |        0.13 |    0.069174 |         0.1 |     0.10003 |       25.398 |       1.7076 |        false |
|    7 | Best   |       0.085 |      0.1168 |       0.085 |     0.08521 |        930.3 |      0.66262 |        false |
|    8 | Accept |        0.35 |    0.066595 |       0.085 |    0.085172 |     0.012972 |        983.4 |         true |
|    9 | Best   |       0.075 |    0.091629 |       0.075 |    0.077959 |       871.26 |      0.40617 |        false |
|   10 | Accept |        0.08 |     0.12545 |       0.075 |    0.077975 |       974.28 |      0.45314 |        false |
|   11 | Accept |       0.235 |     0.30216 |       0.075 |    0.077907 |       920.57 |        6.482 |         true |
|   12 | Accept |       0.305 |    0.070665 |       0.075 |    0.077922 |    0.0010077 |       1.0212 |         true |
|   13 | Best   |        0.07 |    0.080775 |        0.07 |    0.073603 |       991.16 |      0.37801 |        false |
|   14 | Accept |       0.075 |    0.078256 |        0.07 |    0.073191 |       989.88 |      0.24951 |        false |
|   15 | Accept |       0.245 |     0.09407 |        0.07 |    0.073276 |       988.76 |       9.1309 |        false |
|   16 | Accept |        0.07 |      0.0795 |        0.07 |    0.071416 |       957.65 |      0.31271 |        false |
|   17 | Accept |        0.35 |     0.11798 |        0.07 |    0.071421 |    0.0010579 |       33.692 |         true |
|   18 | Accept |       0.085 |     0.05857 |        0.07 |    0.071274 |       48.536 |      0.32107 |        false |
|   19 | Accept |        0.07 |    0.082979 |        0.07 |    0.070587 |       742.56 |      0.30798 |        false |
|   20 | Accept |        0.61 |     0.19356 |        0.07 |    0.070796 |       865.48 |    0.0010165 |        false |
|====================================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   | BoxConstraint|  KernelScale |  Standardize |
|      | result |             | runtime     | (observed)  | (estim.)    |              |              |              |
|====================================================================================================================|
|   21 | Accept |         0.1 |    0.085428 |        0.07 |    0.070715 |       970.87 |      0.14635 |         true |
|   22 | Accept |       0.095 |     0.12121 |        0.07 |     0.07087 |       914.88 |      0.46353 |         true |
|   23 | Accept |        0.07 |     0.14119 |        0.07 |    0.070473 |       982.01 |       0.2792 |        false |
|   24 | Accept |        0.51 |     0.51006 |        0.07 |    0.070515 |    0.0010005 |     0.014749 |         true |
|   25 | Accept |       0.345 |     0.16526 |        0.07 |    0.070533 |    0.0010063 |       972.18 |        false |
|   26 | Accept |       0.315 |     0.17117 |        0.07 |     0.07057 |       947.71 |       152.95 |         true |
|   27 | Accept |        0.35 |     0.36783 |        0.07 |    0.070605 |    0.0010028 |        43.62 |        false |
|   28 | Accept |        0.61 |     0.10346 |        0.07 |    0.070598 |    0.0010405 |    0.0010258 |        false |
|   29 | Accept |       0.555 |     0.07333 |        0.07 |    0.070173 |       993.56 |     0.010502 |         true |
|   30 | Accept |        0.07 |    0.099019 |        0.07 |    0.070158 |       965.73 |      0.25363 |         true |

__________________________________________________________
Optimization completed.
MaxObjectiveEvaluations of 30 reached.
Total function evaluations: 30
Total elapsed time: 16.8267 seconds
Total objective function evaluation time: 4.3552

Best observed feasible point:
    BoxConstraint    KernelScale    Standardize
    _____________    ___________    ___________

       991.16          0.37801         false   

Observed objective function value = 0.07
Estimated objective function value = 0.072292
Function evaluation time = 0.080775

Best estimated feasible point (according to models):
    BoxConstraint    KernelScale    Standardize
    _____________    ___________    ___________

       957.65          0.31271         false   

Estimated objective function value = 0.070158
Estimated function evaluation time = 0.092681

Figure: Min objective vs. Number of function evaluations (x-axis: Function evaluations; y-axis: Min objective), showing the Min observed objective and Estimated min objective curves.

Mdl = 
  ClassificationSVM
                         ResponseName: 'Y'
                CategoricalPredictors: []
                           ClassNames: [-1 1]
                       ScoreTransform: 'none'
                      NumObservations: 200
    HyperparameterOptimizationResults: [1x1 BayesianOptimization]
                                Alpha: [66x1 double]
                                 Bias: -0.0910
                     KernelParameters: [1x1 struct]
                       BoxConstraints: [200x1 double]
                      ConvergenceInfo: [1x1 struct]
                      IsSupportVector: [200x1 logical]
                               Solver: 'SMO'


fitcsvm returns a ClassificationSVM model object that uses the best estimated feasible point. The best estimated feasible point is the set of hyperparameters that minimizes the upper confidence bound of the cross-validation loss based on the underlying Gaussian process model of the Bayesian optimization process.
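To make the selection rule concrete, the sketch below fits a one-dimensional Gaussian process to a few hypothetical (log10 KernelScale, loss) pairs and then picks, among the visited points, the one with the smallest upper confidence bound mu + kappa*sigma. The squared-exponential kernel, its length scale, the noise level, the constant prior mean, and kappa = 1.96 are all assumptions for illustration; bayesopt fits its own kernel and parameters, so only the selection rule carries over.

```matlab
% Hypothetical observations: x = log10(KernelScale), y = CV loss
x = [-1.0; -0.5; -0.4; 0.0; 0.5];
y = [0.20; 0.075; 0.070; 0.10; 0.25];

ell   = 0.3;     % kernel length scale (assumed)
sf2   = var(y);  % signal variance (assumed)
sn2   = 1e-4;    % observation noise variance (assumed)
kappa = 1.96;    % width of the confidence bound (assumed)

% Squared-exponential kernel between column vectors a and b
kern = @(a,b) sf2*exp(-(a - b').^2/(2*ell^2));

% GP posterior at the visited points, with constant prior mean m0
K     = kern(x,x) + sn2*eye(numel(x));   % covariance of noisy observations
Ks    = kern(x,x);                       % query points are the visited points
m0    = mean(y);
mu    = m0 + Ks*(K\(y - m0));            % posterior mean
s2    = sf2 - sum(Ks.*(K\Ks')',2);       % posterior variance
sigma = sqrt(max(s2,0));                 % guard against tiny negative values

% Best estimated point: minimal upper confidence bound among visited points
ucb = mu + kappa*sigma;
[bestUcb,iBest] = min(ucb);
xBest = x(iBest)
```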

The Bayesian optimization process internally maintains a Gaussian process model of the objective function. The objective function is the cross-validated misclassification rate for classification. For each iteration, the optimization process updates the Gaussian process model and uses the model to find a new set of hyperparameters. Each line of the iterative display shows the new set of hyperparameters and these column values:

  • Objective — Objective function value computed at the new set of hyperparameters.

  • Objective runtime — Objective function evaluation time.

  • Eval result — Result report, specified as Accept, Best, or Error. Accept indicates that the objective function returns a finite value, and Error indicates that the objective function returns a value that is not a finite real scalar. Best indicates that the objective function returns a finite value that is lower than previously computed objective function values.

  • BestSoFar(observed) — The minimum objective function value computed so far. This value is either the objective function value of the current iteration (if the Eval result value for the current iteration is Best) or the value of the previous Best iteration.

  • BestSoFar(estim.) — At each iteration, the software estimates the upper confidence bounds of the objective function values, using the updated Gaussian process model, at all the sets of hyperparameters tried so far. Then the software chooses the point with the minimum upper confidence bound. The BestSoFar(estim.) value is the objective function value returned by the predictObjective function at the minimum point.
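The Eval result and BestSoFar(observed) columns follow directly from the Objective column. The sketch below recomputes them from the 30 objective values in the iterative display; it reproduces the Best labels at iterations 1, 5, 7, 9, and 13. A value that only ties the current minimum, such as 0.07 at iteration 16, is labeled Accept because Best requires a strictly lower value. The variable names are for illustration only.

```matlab
% Objective column of the iterative display (iterations 1 to 30)
objective = [0.195 0.345 0.365 0.61 0.1 0.13 0.085 0.35 0.075 0.08 ...
    0.235 0.305 0.07 0.075 0.245 0.07 0.35 0.085 0.07 0.61 ...
    0.1 0.095 0.07 0.51 0.345 0.315 0.35 0.61 0.555 0.07];

% BestSoFar(observed): running minimum of the objective
bestSoFarObserved = cummin(objective);

% Eval result: Error for a non-finite value, Best for a finite value
% strictly lower than all earlier values, Accept otherwise
evalResult = cell(size(objective));
prevBest = Inf;
for t = 1:numel(objective)
    if ~isfinite(objective(t))
        evalResult{t} = 'Error';
    elseif objective(t) < prevBest
        evalResult{t} = 'Best';
        prevBest = objective(t);
    else
        evalResult{t} = 'Accept';
    end
end

bestIters = find(strcmp(evalResult,'Best'))   % iterations labeled Best
```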

The plot below the iterative display shows the BestSoFar(observed) and BestSoFar(estim.) values in blue and green, respectively.

The returned object Mdl uses the best estimated feasible point, that is, the set of hyperparameters that produces the BestSoFar(estim.) value in the final iteration based on the final Gaussian process model.

You can obtain the best point from the HyperparameterOptimizationResults property or by using the bestPoint function.

Mdl.HyperparameterOptimizationResults.XAtMinEstimatedObjective
ans=1×3 table
    BoxConstraint    KernelScale    Standardize
    _____________    ___________    ___________

       957.65          0.31271         false   

[x,CriterionValue,iteration] = bestPoint(Mdl.HyperparameterOptimizationResults)
x=1×3 table
    BoxConstraint    KernelScale    Standardize
    _____________    ___________    ___________

       957.65          0.31271         false   

CriterionValue = 0.0724
iteration = 16

By default, the bestPoint function uses the 'min-visited-upper-confidence-interval' criterion. This criterion chooses the hyperparameters obtained from the 16th iteration as the best point. CriterionValue is the upper bound of the cross-validated loss computed by the final Gaussian process model. Compute the actual cross-validated loss by using the partition c.

L_MinEstimated = kfoldLoss(fitcsvm(cdata,grp,'CVPartition',c, ...
    'KernelFunction','rbf','BoxConstraint',x.BoxConstraint, ...
    'KernelScale',x.KernelScale,'Standardize',x.Standardize=='true'))
L_MinEstimated = 0.0700

The actual cross-validated loss, 0.0700, is close to the estimated objective function value of 0.070158, which the optimization summary lists under Best estimated feasible point.

You can also extract the best observed feasible point (that is, the last Best point in the iterative display) from the HyperparameterOptimizationResults property or by specifying Criterion as 'min-observed'.

Mdl.HyperparameterOptimizationResults.XAtMinObjective
ans=1×3 table
    BoxConstraint    KernelScale    Standardize
    _____________    ___________    ___________

       991.16          0.37801         false   

[x_observed,CriterionValue_observed,iteration_observed] = ...
    bestPoint(Mdl.HyperparameterOptimizationResults,'Criterion','min-observed')
x_observed=1×3 table
    BoxConstraint    KernelScale    Standardize
    _____________    ___________    ___________

       991.16          0.37801         false   

CriterionValue_observed = 0.0700
iteration_observed = 13

The 'min-observed' criterion chooses the hyperparameters obtained from the 13th iteration as the best point. CriterionValue_observed is the actual cross-validated loss computed using the selected hyperparameters. For more information, see the Criterion name-value argument of bestPoint.
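Several iterations (13, 16, 19, 23, and 30) reach the same observed minimum of 0.07. The reported iteration 13 is the first of them, which is also what MATLAB's min returns when the minimum is tied. The sketch below only reproduces that index from the Objective column; that bestPoint breaks ties in the same way is an assumption, not something this example documents.

```matlab
% Objective column of the iterative display (iterations 1 to 30)
objective = [0.195 0.345 0.365 0.61 0.1 0.13 0.085 0.35 0.075 0.08 ...
    0.235 0.305 0.07 0.075 0.245 0.07 0.35 0.085 0.07 0.61 ...
    0.1 0.095 0.07 0.51 0.345 0.315 0.35 0.61 0.555 0.07];

% min returns the index of the first occurrence of the minimum
[minObjective,iterMin] = min(objective)

% All iterations that tie the minimum
tiedIters = find(objective == minObjective)
```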

Visualize the optimized classifier.

d = 0.02;
[x1Grid,x2Grid] = meshgrid(min(cdata(:,1)):d:max(cdata(:,1)), ...
    min(cdata(:,2)):d:max(cdata(:,2)));
xGrid = [x1Grid(:),x2Grid(:)];
[~,scores] = predict(Mdl,xGrid);

figure
h(1:2) = gscatter(cdata(:,1),cdata(:,2),grp,'rg','+*');
hold on
h(3) = plot(cdata(Mdl.IsSupportVector,1), ...
    cdata(Mdl.IsSupportVector,2),'ko');
contour(x1Grid,x2Grid,reshape(scores(:,2),size(x1Grid)),[0 0],'k');
legend(h,{'-1','+1','Support Vectors'},'Location','Southeast');

Figure: training data for classes -1 and +1, the support vectors (black circles), and the decision boundary (contour at score 0).
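The contour at level 0 is the decision boundary: scores(:,2) holds the score for the positive class 1, and a point is assigned to class 1 when that score is positive. For the Gaussian (RBF) kernel, the score has the form f(x) = sum_j alpha_j*y_j*exp(-||(x - x_j)/s||^2) + b, where the sum runs over the support vectors x_j with labels y_j and s is the kernel scale. The sketch below evaluates this formula for hypothetical support vectors and coefficients. For the trained model, the corresponding quantities are stored in Mdl.SupportVectors, Mdl.SupportVectorLabels, Mdl.Alpha, Mdl.Bias, and Mdl.KernelParameters.Scale, but this sketch does not check that correspondence and ignores standardization (relevant when Standardize is true).

```matlab
% Hypothetical trained quantities (not taken from Mdl)
sv    = [0.9 0.1; 1.2 -0.2; 0.1 0.9; -0.2 1.1];  % support vectors
svY   = [1; 1; -1; -1];                          % their class labels
alpha = [0.8; 0.5; 0.7; 0.6];                    % dual coefficients
b     = 0.05;                                    % bias
s     = 0.35;                                    % kernel scale

% Score for each row of X: squared distances to every support vector,
% Gaussian kernel on the scaled distances, weighted sum plus bias
scoreFcn = @(X) exp(-(sum(X.^2,2) + sum(sv.^2,2)' - 2*X*sv')/s^2) ...
    * (alpha.*svY) + b;

xq = [1 0; 0 1; 0.5 0.5];
f = scoreFcn(xq);
labels = sign(f)   % +1 where the score is positive, -1 where negative
```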

Evaluate Accuracy on New Data

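Generate new test data points and classify them. The test points come from Gaussian mixtures centered on the same base points, but with covariance 0.2*I, that is, the variance I/5 of the original model rather than the I/50 used for the training data.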

grnobj = gmdistribution(grnpop,.2*eye(2));
redobj = gmdistribution(redpop,.2*eye(2));

newData = random(grnobj,10);
newData = [newData;random(redobj,10)];
grpData = ones(20,1); % green = 1
grpData(11:20) = -1; % red = -1

v = predict(Mdl,newData);

Compute the misclassification rate on the test data set.

L_Test = loss(Mdl,newData,grpData)
L_Test = 0.2000
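With equal observation weights and the same number of test points in each class, as here, the default classification error that loss reports should reduce to the fraction of misclassified test points, mean(v ~= grpData). The sketch below computes that fraction for hypothetical label vectors in which 4 of the 20 predictions are wrong, giving 4/20 = 0.2; the vectors are made up for illustration and are not the predictions from this example.

```matlab
% Hypothetical true labels and predictions for 20 test points
trueLabels = [ones(10,1); -ones(10,1)];
predLabels = trueLabels;
flipIdx = [3 8 12 17];                         % hypothetical errors
predLabels(flipIdx) = -predLabels(flipIdx);

% Misclassification rate: fraction of points whose prediction is wrong
errRate = mean(predLabels ~= trueLabels)
```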

Determine which new data points are classified correctly. Mark the correctly classified points with red squares and the misclassified points with black squares.

h(4:5) = gscatter(newData(:,1),newData(:,2),v,'mc','**');

isCorrect = (v == grpData); % True for correctly classified points

% Plot red squares around the correctly classified points
h(6) = plot(newData(isCorrect,1),newData(isCorrect,2),'rs','MarkerSize',12);

% Plot black squares around the misclassified points
h(7) = plot(newData(~isCorrect,1),newData(~isCorrect,2),'ks','MarkerSize',12);
legend(h,{'-1 (training)','+1 (training)','Support Vectors', ...
    '-1 (classified)','+1 (classified)', ...
    'Correctly Classified','Misclassified'}, ...
    'Location','Southeast');
hold off

Figure: training data (-1, +1), support vectors, decision boundary, and the classified test points (-1, +1), with red squares marking correctly classified points and black squares marking misclassified points.
