Parasite Classification Using Wavelet Scattering and Deep Learning

This example shows how to classify parasitic infections in Giemsa stain images using wavelet image scattering and deep learning. The dataset is challenging for deep networks because it contains only 48 images. The images are divided evenly into three categories of parasitic infections: babesiosis, plasmodium-gametocyte, and trypanosomiasis.

Data

Obtain the data from the MATLAB® File Exchange submission Deploying Deep Neural Networks to Embedded GPUs. The following code downloads the zip file to the current folder and unzips it there.

url = "https://www.mathworks.com/matlabcentral/mlc-downloads/downloads/" + ...
    "5918495a-0009-419e-8e10-77b06e3fe553/844e43fa-7c50-4f88-a435-f0afe04fc3a3/" + ...
    "packages/zip";
websave("classifyBloodSmearImages.zip",url);
unzip('classifyBloodSmearImages.zip')

Create an ImageDatastore to manage access to the Giemsa stain images. The images are in RGB format with a common size of 300-by-300-by-3.

imagedir = fullfile('classifyBloodSmearImages','BloodSmearImages');
Imds = imageDatastore(imagedir,'IncludeSubFolders',true,'FileExtensions',...
    '.jpg','LabelSource','foldernames');
summary(Imds.Labels)
     babesiosis                 16 
     plasmodium-gametocyte      16 
     trypanosomiasis            16 

There are 16 images for each of the three parasite types. Split the data into training and hold-out test sets, with approximately 70 percent of the images (11 per class) in the training set and the remaining 30 percent (5 per class) in the test set. Set the random number generator for reproducibility.

rng default
[trainImds,testImds] = splitEachLabel(Imds,0.7);

Verify that equal numbers of each parasite class are contained in both the training and test sets.

summary(trainImds.Labels)
     babesiosis                 11 
     plasmodium-gametocyte      11 
     trypanosomiasis            11 
% Perform the same for the test set.
summary(testImds.Labels)
     babesiosis                 5 
     plasmodium-gametocyte      5 
     trypanosomiasis            5 

Because this is a small dataset, the entire training and test sets fit in memory. Read all images for both sets.

trainImages = readall(trainImds);
testImages = readall(testImds);

Plot some sample images from the training data.

idx = randperm(numel(trainImages),6);
figure
for ii = 1:length(idx)
    im = trainImages{idx(ii)};
    subplot(3,2,ii)
    imshow(im,[])
    title(string(trainImds.Labels(idx(ii))));
end

Wavelet Scattering Network

In this example, you use a wavelet scattering transform as the feature extractor for the machine learning approaches. The wavelet scattering transform helps to reduce the dimensionality of the data and increase the interclass dissimilarity. Construct a two-layer image scattering network with a 40-by-40 pixel invariance scale. Use two wavelets per octave in the first layer and one wavelet per octave in the second layer. Use two rotations of the wavelets per layer.

sn = waveletScattering2('ImageSize',[300 300],'InvarianceScale',40,...
    'QualityFactors',[2 1],'NumRotations',[2 2]);
[~,npaths] = paths(sn);
sum(npaths)
ans = 27
coefficientSize(sn)
ans = 1×2

    38    38

The specified wavelet scattering network has 27 paths. The image on each scattering path is reduced to 38-by-38-by-3. Even without further averaging of the scattering coefficients, this is a reduction in the size of each image's memory by more than a factor of 2. However, for classification we form a feature vector that averages the scattering coefficients over the spatial and channel dimensions. This results in feature vectors with only 27 elements, a real-valued scalar for each scattering path. This represents a reduction in the number of elements by a factor of 10,000 for each image.
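
The size reductions quoted above can be checked with a few lines of arithmetic. This is a minimal sketch that uses only the dimensions stated in the text; the variable names are illustrative and not part of the example.

```matlab
% Element counts taken from the dimensions stated in the text
imageElems = 300*300*3;      % one RGB input image
scatElems  = 27*38*38*3;     % 27 paths, each reduced to 38-by-38-by-3
featElems  = 27;             % one averaged value per scattering path

ratioScattering = imageElems/scatElems   % about 2.31
ratioFeatures   = imageElems/featElems   % 10000
```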

The following code computes the wavelet scattering feature vectors for both the training and test sets. Concatenate the feature vectors so that you have N-by-27 matrices, where N is the number of examples in the training or test set and each row is a wavelet scattering feature vector for an example.

trainfeatures = cellfun(@(x)helperScatImages_mean(sn,x),trainImages,'UniformOutput',false);
testfeatures = cellfun(@(x)helperScatImages_mean(sn,x),testImages,'UniformOutput',false);
trainfeatures = cat(1,trainfeatures{:});
testfeatures = cat(1,testfeatures{:});
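
To see how the averaging and concatenation steps produce an N-by-27 matrix without running the scattering network, the following sketch substitutes random arrays for the scattering outputs. The 27-by-38-by-38-by-3 size is taken from the text; the random data and the variable names smats, feats, and featMatrix are assumptions made only for illustration.

```matlab
% Random stand-ins for the scattering outputs of four images
% (paths-by-rows-by-columns-by-channels)
smats = {rand(27,38,38,3); rand(27,38,38,3); rand(27,38,38,3); rand(27,38,38,3)};

% Average each output over the spatial and channel dimensions (2 through 4)
% and transpose so that every image yields a 1-by-27 row vector
feats = cellfun(@(s) mean(s,2:4)', smats, 'UniformOutput', false);

% Stack the row vectors into an N-by-27 feature matrix (here N = 4)
featMatrix = cat(1,feats{:});
size(featMatrix)
```

Each row of featMatrix then corresponds to one image, which is the layout that the classifiers in the following sections expect.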

SVM Classification

Use an SVM classifier with the scattering features. Choose a cubic polynomial kernel. Use a one-vs-all coding scheme.

template = templateSVM(...
    'KernelFunction', 'polynomial', ...
    'PolynomialOrder', 3, ...
    'KernelScale', 1, ...
    'BoxConstraint', 314, ...
    'Standardize', true);
classificationSVM = fitcecoc(trainfeatures,trainImds.Labels,...
    'Learners', template, 'Coding', 'onevsall');

Estimate the accuracy on the training set using cross-validation with 5 folds.

kfoldmodel = crossval(classificationSVM, 'KFold', 5);
loss = kfoldLoss(kfoldmodel)*100;
crossvalAccuracy = 100-loss
crossvalAccuracy = single
    81.8182

The cross-validation accuracy is approximately 82 percent. Now examine the accuracy on the held-out test set and plot the confusion chart.

[predLabels,scores] = predict(classificationSVM,testfeatures);
testAccuracy = ...
    sum(categorical(predLabels)== testImds.Labels)/numel(testImds.Labels)*100
testAccuracy = 80
figure
cchart = confusionchart(testImds.Labels,predLabels);
cchart.Title = ...
    {'Confusion Chart for Wavelet' ; 'Scattering Features using SVM'};
cchart.RowSummary = 'row-normalized';
cchart.ColumnSummary = 'column-normalized';

The overall test accuracy is 80 percent with the SVM model, and the recall for each class is also 80 percent. The precision is high for the plasmodium-gametocyte and trypanosomiasis parasites, but lower for babesiosis. Examine the F1 scores for each class.

f1SVM = f1score(cchart.NormalizedValues);
disp(f1SVM)
                               F1   
                             _______

    babesiosis               0.72727
    plasmodium-gametocyte    0.88889
    trypanosomiasis              0.8

All F1 scores are between approximately 0.7 and 0.9.
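
The relationship between the confusion chart and the F1 scores can be sketched directly. The matrix C below is hypothetical: its diagonal, row sums, and column sums are consistent with the recall and F1 values reported above, but the placement of the off-diagonal entries is an assumption, because the chart itself is not reproduced in the text.

```matlab
% Hypothetical confusion matrix (rows: true class, columns: predicted class)
% Class order: babesiosis, plasmodium-gametocyte, trypanosomiasis
C = [4 0 1;
     1 4 0;
     1 0 4];

tp        = diag(C)';              % correctly classified examples per class
precision = tp ./ sum(C,1);        % divide by the number predicted as each class
recall    = tp ./ sum(C,2)';       % divide by the number truly in each class
f1        = 2*precision.*recall ./ (precision + recall);
accuracy  = 100*sum(tp)/sum(C,'all');

disp([precision; recall; f1])
```

For this matrix the F1 scores come out as approximately 0.727, 0.889, and 0.8, and the accuracy is 80 percent, matching the values above.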

PCA Classifier with Scattering Features

Support vector machines are powerful techniques for features that are not linearly separable, but they are designed for binary classification and may be suboptimal for multiclass problems. Here you complement the SVM analysis by using a simple PCA (linear) classifier with the same wavelet scattering features. The helperPCAModel function determines the numcomp eigenvectors corresponding to the largest eigenvalues of the covariance matrix of the wavelet scattering features for each pathogen in the training set along with the class means.

helperPCAClassifier classifies each test sample. It does this by subtracting the model class means from each wavelet scattering feature vector in the test dataset and projecting the centered feature vectors onto the covariance-matrix eigenvectors for each class in the model. helperPCAClassifier assigns each test example to the pathogen with the smallest error, or residual. This is a principal components analysis (PCA) classifier.
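
The idea can be illustrated on toy data in base MATLAB. This is a minimal sketch and not the helper functions themselves: the two classes, their directions, the noise level, and all variable names are assumptions chosen so that the result is easy to verify by hand.

```matlab
% Toy data: columns are examples. Two hypothetical classes lie close to
% different lines in 3-D space.
rng(0)
t = linspace(-1,1,20);
classA = [t; 2*t; 0*t] + 0.01*randn(3,20);            % near direction [1 2 0]
classB = [2;2;2] + [0*t; t; 3*t] + 0.01*randn(3,20);  % near direction [0 1 3]
classes = {classA, classB};

numcomp = 1;                     % principal directions kept per class
mu = cell(1,2);
U = cell(1,2);
for k = 1:2
    X = classes{k};
    mu{k} = mean(X,2);                       % class mean
    [V,D] = eig(cov(X'));                    % eigenvectors of the class covariance
    [~,order] = sort(diag(D),'descend');     % largest eigenvalues first
    U{k} = V(:,order(1:numcomp));            % leading eigenvectors
end

% Two test points, the first on the class A line, the second on the class B line
xTest = [0.5 2.0;
         1.0 2.5;
         0.0 3.5];
resid = zeros(2,size(xTest,2));
for k = 1:2
    s = xTest - mu{k};                       % center on the class mean
    r = s - U{k}*(U{k}'*s);                  % remove the part explained by the class
    resid(k,:) = sqrt(sum(r.^2,1));          % residual norm per example
end
[~,predicted] = min(resid,[],1)              % expected: 1 2
```

Each test point is assigned to the class whose mean and principal direction leave the smallest residual, which is the same decision rule that helperPCAClassifier applies to the scattering features.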

Remove the 0-th order scattering features from each feature vector. Set the number of principal components (eigenvectors) to 6.

numcomp = 6;
model = helperPCAModel(trainfeatures(:,2:end)',numcomp,trainImds.Labels);
PCALabels = helperPCAClassifier(testfeatures(:,2:end)',model);
testPCAacc = sum(PCALabels==testImds.Labels)/numel(testImds.Labels)*100
testPCAacc = 86.6667

The test accuracy is approximately 87% with the PCA classifier. Plot the confusion chart and calculate the F1 scores for each class.

figure
cchart = confusionchart(testImds.Labels,PCALabels);
cchart.Title = {'Confusion Chart for Wavelet Scattering Features' ; ...
    'using PCA Classifier'};
cchart.RowSummary = 'row-normalized';
cchart.ColumnSummary = 'column-normalized';

f1PCA = f1score(cchart.NormalizedValues);
disp(f1PCA)
                               F1   
                             _______

    babesiosis               0.90909
    plasmodium-gametocyte    0.88889
    trypanosomiasis              0.8

The F1 scores for the PCA classifier with wavelet scattering features are all 0.8 or higher, and the score for babesiosis improves from approximately 0.73 with the SVM to approximately 0.91.

Convolutional Deep Network

In this section, you attempt the same classification using deep convolutional networks. Deep networks provide state-of-the-art results for classification problems with large datasets and are capable of learning complicated nonlinear mappings, but their performance often suffers on small datasets. To mitigate this problem, use an image augmenter. imageDataAugmenter perturbs the data in each epoch, in effect creating new training examples.

augmenter = imageDataAugmenter('RandRotation',[0 180],'RandXTranslation', [-5 5], ...
    'RandYTranslation',[-5 5]);
augimds = augmentedImageDatastore([300 300 3],trainImds,'DataAugmentation',augmenter);
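
To get a feel for what the augmenter does to a single image, you can apply it directly with the augment function. The sketch below assumes Deep Learning Toolbox is available and uses a random image as a stand-in for a blood smear image; the names img, aug1, and aug2 are illustrative.

```matlab
% Same augmentation settings as in the example
augmenter = imageDataAugmenter('RandRotation',[0 180], ...
    'RandXTranslation',[-5 5],'RandYTranslation',[-5 5]);

% Random stand-in for a 300-by-300 RGB blood smear image
img = uint8(255*rand(300,300,3));

% Each call draws a new random rotation and translation
aug1 = augment(augmenter,img);
aug2 = augment(augmenter,img);

% The two augmented copies generally differ from each other
differ = ~isequal(aug1,aug2)
```

During training, augmentedImageDatastore applies this kind of random perturbation to every image each time it is read, so the network rarely sees exactly the same input twice.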

Define a small CNN consisting of two convolution layers, each followed by a batch normalization layer and a ReLU activation. Follow the final ReLU activation with max pooling, fully connected, and softmax layers.

layers = [
    imageInputLayer([300 300 3])
    convolution2dLayer(7,16)
    batchNormalizationLayer
    reluLayer    
    convolution2dLayer(3,20)
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(4)
    fullyConnectedLayer(3)
    softmaxLayer
    classificationLayer];
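
The activation sizes and the number of fully connected weights implied by these layers can be estimated by hand. The sketch below assumes that the convolution and pooling layers keep their default stride of 1 and use no padding, since the layer definitions above do not override these settings; the variable names are illustrative.

```matlab
inputSize = 300;                    % input images are 300-by-300-by-3

% With stride 1 and no padding, each layer shrinks the spatial size by
% (window size - 1)
conv1Size = inputSize - 7 + 1;      % 294, with 16 channels
conv2Size = conv1Size - 3 + 1;      % 292, with 20 channels
poolSize  = conv2Size - 4 + 1;      % 289, with 20 channels

% Inputs to and weights of the three-output fully connected layer
fcInputs  = poolSize*poolSize*20    % 1670420
fcWeights = 3*fcInputs              % 5011260
```

Under these assumptions the fully connected layer alone has about five million weights, which is very large relative to the 33 training images.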

Use stochastic gradient descent with momentum and a mini-batch size of 10. Set the initial learning rate to 0.0001, shuffle the data every epoch, and run the training for 100 epochs.

opts = trainingOptions('sgdm',...
    'InitialLearnRate',  0.0001, ...
    'MaxEpochs', 100, ...
    'MiniBatchSize',10,...
    'Shuffle','every-epoch',...
    'Plots', 'training-progress',...
    'Verbose',false,...
    'ExecutionEnvironment','cpu');
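
With 33 training images, a mini-batch size of 10 does not divide the training set evenly. The following sketch counts the resulting iterations, assuming that trainNetwork discards the images that do not fit into the last complete mini-batch of each epoch; the variable names are illustrative.

```matlab
numTrain      = 33;     % training images (11 per class)
miniBatchSize = 10;
maxEpochs     = 100;

itersPerEpoch  = floor(numTrain/miniBatchSize)             % 3
unusedPerEpoch = numTrain - itersPerEpoch*miniBatchSize    % 3
totalIters     = itersPerEpoch*maxEpochs                   % 300
```

Shuffling every epoch means that a different set of three images is left out each time, so all training images contribute over the course of training.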

Train the network.

trainedNet = trainNetwork(augimds,layers,opts);

Examine the performance of the network on the held-out test set.

ypred = trainedNet.classify(testImds);
cnnAccuracy = sum(ypred == testImds.Labels)/numel(testImds.Labels)*100
cnnAccuracy = 66.6667
figure
cchart = confusionchart(testImds.Labels,ypred);
cchart.Title = 'Confusion Chart for Deep CNN';
cchart.RowSummary = 'row-normalized';
cchart.ColumnSummary = 'column-normalized';

f1CNN = f1score(cchart.NormalizedValues);
disp(f1CNN)
                               F1   
                             _______

    babesiosis                  0.75
    plasmodium-gametocyte    0.76923
    trypanosomiasis          0.44444

In spite of using an augmented dataset for training, the CNN appears to have overfit the training set. Its test accuracy of approximately 67 percent is lower than that of either the SVM or PCA model with the wavelet scattering features, and its F1 score for trypanosomiasis (0.44) is much worse.

Next, use transfer learning with SqueezeNet. SqueezeNet was trained to recognize 1,000 classes, so replace its final convolutional layer with one that has three filters, one for each class of pathogen, and replace the classification layer.

net = squeezenet;
lgraphSQZ = layerGraph(net);
numClasses = numel(categories(trainImds.Labels));
oldFinalConv = lgraphSQZ.Layers(end-4);
newFinalConv = convolution2dLayer(1,numClasses, ...
        'Name','new_conv');
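% setLearnRateFactor returns a modified copy of the layer, so assign the output
newFinalConv = setLearnRateFactor(newFinalConv,'Weights',10);
newFinalConv = setLearnRateFactor(newFinalConv,'Bias',10)
newFinalConv = 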
  Convolution2DLayer with properties:

              Name: 'new_conv'

   Hyperparameters
        FilterSize: [1 1]
       NumChannels: 'auto'
        NumFilters: 3
            Stride: [1 1]
    DilationFactor: [1 1]
       PaddingMode: 'manual'
       PaddingSize: [0 0 0 0]
      PaddingValue: 0

   Learnable Parameters
           Weights: []
              Bias: []

  Show all properties

lgraphSQZ = replaceLayer(lgraphSQZ,oldFinalConv.Name,newFinalConv);
oldClassLayer= lgraphSQZ.Layers(end);
newClassLayer = classificationLayer('Name','new_classoutput');
lgraphSQZ = replaceLayer(lgraphSQZ,oldClassLayer.Name,newClassLayer);

Reset the training and test datastores. Modify the datastore read function to resize images to be compatible with SqueezeNet, which expects 227-by-227-by-3 images. Set up the image augmenter and train the network.

reset(trainImds);
reset(testImds);
trainImds.ReadFcn = @(x)imresize(imread(x),'OutputSize',[227 227]);
testImds.ReadFcn = @(x)imresize(imread(x),'OutputSize',[227 227]);
augmenter = imageDataAugmenter('RandRotation',[0 180],'RandXTranslation', [-5 5], ...
    'RandYTranslation',[-5 5]);
augimds = augmentedImageDatastore([227 227 3],trainImds,...
    'DataAugmentation',augmenter);
trainedNet = trainNetwork(augimds,lgraphSQZ,opts);

Obtain the SqueezeNet accuracy, plot the confusion chart, and compute the F1 scores.

ypred = trainedNet.classify(testImds);
sqznetAccuracy = sum(ypred == testImds.Labels)/numel(testImds.Labels)*100
sqznetAccuracy = 73.3333
figure
cchart = confusionchart(testImds.Labels,ypred);
cchart.Title = {'Confusion Chart for Transfer Learning' ; 'with SqueezeNet'};
cchart.RowSummary = 'row-normalized';
cchart.ColumnSummary = 'column-normalized';

f1SqueezeNet = f1score(cchart.NormalizedValues);
disp(f1SqueezeNet)
                               F1   
                             _______

    babesiosis               0.72727
    plasmodium-gametocyte        0.8
    trypanosomiasis          0.66667

SqueezeNet achieves a higher test accuracy than the small CNN (approximately 73 percent versus 67 percent), mainly because of its better F1 score for trypanosomiasis, but it does not match the approximately 87 percent accuracy of the PCA classifier with the wavelet scattering features.

Summary

In this example, the wavelet scattering transform and deep learning frameworks were used to classify pathogens in Giemsa stain images. The limited dataset size makes it difficult to train a deep learning classifier, even when data augmentation is used. The example illustrated that the wavelet scattering transform can provide a useful alternative to deep networks in such cases. In forming feature vectors from the wavelet scattering transform, we reduced each transform output from a 27-by-38-by-38-by-3 tensor to a 27-element vector by averaging over the spatial and channel dimensions, that is, by global average pooling of the scattering coefficients. Other pooling schemes are possible and could yield better results.
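
As one illustration of a different pooling scheme, the sketch below averages each scattering path over a 2-by-2 grid of spatial blocks instead of over the whole image, which keeps some coarse location information. This scheme is hypothetical and was not evaluated in the example; the random tensor stands in for a scattering output, and the names globalFeat and gridFeat are assumptions.

```matlab
% Random stand-in for one scattering output (paths-by-rows-by-columns-by-channels)
smat = rand(27,38,38,3);

% Global pooling, as used in the example: one value per path
globalFeat = mean(smat,2:4)';                       % 1-by-27

% Hypothetical alternative: split rows and columns into two blocks of 19
% and average within each block and over the channels
blocks   = reshape(smat,27,19,2,19,2,3);
gridFeat = reshape(mean(blocks,[2 4 6]),1,[]);      % 1-by-108

% Averaging the four block means of a path recovers its global mean
maxDiff = max(abs(mean(reshape(gridFeat,27,4),2)' - globalFeat))
```

The resulting 108-element vectors could be passed to the same SVM or PCA classifiers, at the cost of four times as many features for the same number of training images.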

Appendix — Supporting Functions

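function features = helperScatImages_mean(sn,x)
% Scattering coefficients: paths-by-rows-by-columns-by-channels
smat = featureMatrix(sn,x);
% Average over the spatial and channel dimensions, one value per path
features = mean(smat,2:4);
% Return a row vector so that examples can be stacked as rows
features = features';
end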
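function F1scores = f1score(cchartVal)
% Compute per-class F1 scores from a confusion matrix of counts
% (rows: true class, columns: predicted class).
N = sum(cchartVal,'all');
% Fraction of all test examples predicted as each class (column sums)
probT = sum(cchartVal)./N;
% Fraction of all test examples correctly assigned to each class
classProbEst = diag(cchartVal)./N;
Prec = classProbEst'./probT;
% Fraction of all test examples that truly belong to each class (row
% sums); equals [5/15 5/15 5/15] for the balanced test set in this example
probC = sum(cchartVal,2)'./N;
Recall = classProbEst'./probC;
% F1 is the harmonic mean of precision and recall
F1scores = harmmean([Prec ; Recall]);
F1scores = F1scores';
F1scores = table(F1scores,'VariableNames',{'F1'},...
    'RowNames', {'babesiosis','plasmodium-gametocyte', 'trypanosomiasis'});
end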

function labels = helperPCAClassifier(features,model)
% This function is only to support wavelet image scattering examples in 
% Wavelet Toolbox. It may change or be removed in a future release.
% model is a structure with fields Dim, mu, U, Labels, and S.
% features is the matrix of test data which is Ns-by-L, Ns is the number of
% scattering paths and L is the number of test examples. Each column of
% features is a test example.

% Copyright 2018-2021 MathWorks

labelIdx = determineClass(features,model); 
labels = model.Labels(labelIdx); 
% Returns as column vector to agree with imageDatastore Labels
labels = labels(:);


%--------------------------------------------------------------------------
function labelIdx = determineClass(features,model)
% Determine number of classes
Nclasses = numel(model.Labels);
% Initialize error matrix
errMatrix = Inf(Nclasses,size(features,2));
for nc = 1:Nclasses
    % class centroid
    mu = model.mu{nc};
    u = model.U{nc};
    % 1-by-L
    errMatrix(nc,:) = projectionError(features,mu,u);
end
% Determine minimum along class dimension
[~,labelIdx] = min(errMatrix,[],1);   


%--------------------------------------------------------------------------
function totalerr = projectionError(features,mu,u)
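    % Residual norm of each centered example after projecting onto the
    % principal directions in u
    Npc = size(u,2);
    L = size(features,2);
    % Subtract class mean: Ns-by-L minus Ns-by-1
    s = features-mu;
    % Squared norm of each centered example, L-by-1
    normSqX = sum(abs(s).^2,1)';
    err = Inf(Npc+1,L);
    err(1,:) = normSqX;
    % Subtract the squared projections onto each principal direction
    err(2:end,:) = -abs(u'*s).^2;
    % Residual norm for each example, 1-by-L
    totalerr = sqrt(sum(err,1));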
end
end
end

function model = helperPCAModel(features,M,Labels)
% This function is only to support wavelet image scattering examples in
% Wavelet Toolbox. It may change or be removed in a future release.
% model = helperPCAModel(features,M,Labels)

% Copyright 2018-2021 MathWorks

% Initialize structure array to hold the affine model
model = struct('Dim',[],'mu',[],'U',[],'Labels',categorical([]),'S',[]);
model.Dim = M;
% Obtain the number of classes
LabelCategories = categories(Labels);
Nclasses = numel(categories(Labels));
for kk = 1:Nclasses
    Class = LabelCategories{kk};
    % Find indices corresponding to each class
    idxClass = Labels == Class;
    % Extract feature vectors for each class
    tmpFeatures = features(:,idxClass);
    % Determine the mean for each class
    model.mu{kk} = mean(tmpFeatures,2);
    [model.U{kk},model.S{kk}] = scatPCA(tmpFeatures);
    if size(model.U{kk},2) > M
        model.U{kk} = model.U{kk}(:,1:M);
        model.S{kk} = model.S{kk}(1:M);
        
    end
    model.Labels(kk) = Class;
end

    function [u,s] = scatPCA(x)
        % Calculate the principal components of x along the second dimension.
        [u,d] = eig(cov(x'));
        % Sort eigenvalues of covariance matrix in descending order
        [s,ind] = sort(diag(d),'descend');
        % sort eigenvector matrix accordingly
        u = u(:,ind);
    end
end
