Prune and Quantize Convolutional Neural Network for Speech Recognition

This example shows how to compress a convolutional neural network (CNN) to prepare it for deployment on an embedded system.

Deploying deep learning models on embedded systems can be challenging due to the limited memory and processing power of embedded systems. Model compression addresses these limitations by reducing the memory footprint of a model.

This example covers two compression techniques for deep learning models: Taylor pruning and quantization. Taylor pruning reduces the size of the network and increases inference speed by removing convolution filters. Quantization then reduces the memory requirements of the network further by converting the weights, biases, and activations of layers to 8-bit scaled integer data types.
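To build intuition for the quantization step, this sketch shows the idea behind scaled 8-bit integer quantization: map values to int8 using a power-of-two scaling factor chosen from their dynamic range. This is a simplified illustration, not the dlquantizer implementation, and w is a hypothetical weight vector.

% Simplified sketch of scaled int8 quantization. dlquantizer chooses
% scaling exponents from calibration statistics; this only illustrates the idea.
w = 0.5*randn(1,8);                       % hypothetical layer weights
exponent = ceil(log2(max(abs(w))/127));   % power-of-two scale so values fit in int8
scale = 2^exponent;
wInt8 = int8(round(w/scale));             % quantized 8-bit representation
wDequant = double(wInt8)*scale;           % values reconstructed at inference
quantError = max(abs(w - wDequant))       % quantization error, bounded by scale/2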

The network you use in this example is trained to recognize speech commands in Train Speech Command Recognition Model Using Deep Learning.

Load Data

This example uses the Google Speech Commands Dataset [1]. Download and unzip the data set.

downloadFolder = matlab.internal.examples.downloadSupportFile("audio","google_speech.zip");
dataFolder = tempdir;
unzip(downloadFolder,dataFolder)
dataset = fullfile(dataFolder,"google_speech");

Create Training and Validation Data

Create training and validation datastores before loading the pretrained network. This section follows the steps in Train Speech Command Recognition Model Using Deep Learning to augment the data, create datastores, and extract auditory spectrograms.

Augment Data

The network must not only recognize different spoken words but also detect whether the audio input is silence or background noise.

The supporting function augmentDataset uses the long audio files in the background folder of the Google Speech Commands Dataset to create one-second segments of background noise. The function creates an equal number of background segments from each background noise file and then splits the segments between the training and validation folders.

augmentDataset(dataset)

Create Training and Validation Datastores

Use the supporting function createDatastores to create training and validation datastores. The function accepts a categorical array specifying the words that you want your model to recognize as commands and returns training and validation datastores, adsTrain and adsValidation.

commands = categorical(["yes","no","up","down","left","right","on","off","stop","go"]);
[adsTrain,adsValidation] = createDatastores(dataset,commands);

Extract Features

Use the supporting function extractFeatures to extract the auditory spectrograms from the audio input. XTrain contains the spectrograms from the training datastore, and XValidation contains the spectrograms from the validation datastore. TTrain and TValidation are the training and validation target labels, isolated for convenience.

[XTrain,XValidation,TTrain,TValidation] = extractFeatures(adsTrain,adsValidation);

Load Pretrained Network

Load the trained network from Train Speech Command Recognition Model Using Deep Learning.

load("trainedCommandNet.mat")

Evaluate Trained Network

Calculate the network accuracy prior to model compression. Use the classify function to compute the accuracy of the network on the training and validation sets.

YValidation = classify(trainedNet,XValidation);
validationAccuracy = mean(YValidation == TValidation);
YTrain = classify(trainedNet,XTrain);
trainAccuracy = mean(YTrain == TTrain);

disp(["Training Accuracy: " + trainAccuracy*100 + "%";"Validation Accuracy: " + validationAccuracy*100 + "%"])
    "Training Accuracy: 96.3918%"
    "Validation Accuracy: 94.2779%"

To plot the confusion matrix for the validation set, use confusionchart. Display the precision and recall for each class by using column and row summaries.

figure(Units="normalized",Position=[0.2,0.2,0.5,0.5]);
cm = confusionchart(TValidation,YValidation, ...
    Title="Original Network", ...
    ColumnSummary="column-normalized",RowSummary="row-normalized");
sortClasses(cm,[commands,"unknown","background"])

Prepare Network and Data for Pruning

Convert the network and data into a format that is compatible with the pruning workflow. This process involves creating datastores and mini-batch queues from the training and validation data sets and converting the network to a dlnetwork object.

Create datastores dsTrain and dsValidation from the spectrograms used for network training and validation.

classes = categories(TTrain);
dsTrain = augmentedImageDatastore([98 50], XTrain, TTrain);
dsValidation = augmentedImageDatastore([98 50], XValidation, TValidation);

Create minibatchqueue objects for the training and validation data for use in the custom pruning loops.

miniBatchSize = 50;

executionEnvironment = "auto";

mbqTrain = minibatchqueue(dsTrain, ...
    MiniBatchSize=miniBatchSize, ...
    MiniBatchFcn=@preprocessMiniBatch, ...
    MiniBatchFormat=["SSCB",""]);

mbqValidation = minibatchqueue(dsValidation, ...
    MiniBatchSize=miniBatchSize, ...
    MiniBatchFcn=@preprocessMiniBatch, ...
    MiniBatchFormat=["SSCB",""]);

Remove the classification layer from the trained network and convert it to a dlnetwork object.

lgraph = layerGraph(trainedNet);
lgraph = removeLayers(lgraph,"classoutput");

dlnet = dlnetwork(lgraph);

Prune Network

Use Taylor pruning to reduce the size of the network. Taylor pruning assigns an importance score to each convolution filter and removes the least important filters over several pruning iterations, reducing the overall network size. To learn more about Taylor pruning, see Prune Image Classification Network Using Taylor Scores.
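As a rough illustration of the scoring, the loss change from removing a filter is approximated by the product of the filter's activations and the gradients of the loss with respect to those activations, accumulated over mini-batches. The sketch below uses hypothetical activation and gradient arrays; in this example, the taylorPrunableNetwork object computes and accumulates the scores internally through updateScore.

% Simplified sketch of first-order Taylor scoring for 16 hypothetical filters.
A = randn(12,12,16);                      % hypothetical activation maps (H-by-W-by-filters)
G = randn(12,12,16);                      % hypothetical gradients dLoss/dA
scores = squeeze(sum(abs(A.*G),[1 2]));   % one importance score per filter
[~,idx] = sort(scores);
leastImportantFilters = idx(1:4)          % filters that would be pruned first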

Convert the network to a taylorPrunableNetwork object.

prunableNet = taylorPrunableNetwork(dlnet);

Specify Pruning and Fine-Tuning Options

Set the pruning options.

  • maxPruningIterations sets the maximum number of iterations for the pruning process.

  • maxToPrune sets the maximum number of filters to prune in each iteration of the pruning cycle.

maxPruningIterations = 16; 
maxToPrune = 8; 
maxPrunableFilters = prunableNet.NumPrunables;
numTest = size(TValidation,1);
minPrunables = 5;

Set the fine-tuning options.

learnRate = 1e-2;
momentum = 0.9;
numMinibatchUpdates = 50;
validationFrequency = 1;

Prune Network

Prune the network. The supporting function taylorPruningLoop defines the pruning iterations for each mini-batch. Each pruning iteration performs these steps:

  1. Evaluate the pruning activations, gradients of the pruning activations, model gradients, state, and loss.

  2. Update the network state.

  3. Update the network parameters according to the optimizer.

  4. Compute first-order Taylor scores and accumulate scores across previous batches of data.

  5. Display progress.

prunableNet = taylorPruningLoop(prunableNet, mbqTrain, mbqValidation, classes, numTest, maxPruningIterations, ...
                                maxPrunableFilters, maxToPrune, minPrunables, learnRate, ...
                                momentum, numMinibatchUpdates, validationFrequency,trainAccuracy);

During each pruning iteration, the validation accuracy often decreases because pruning convolution filters changes the network structure. To regain accuracy, you can retrain the network.

Retrain Pruned Network

Retrain the network after pruning to regain any loss in accuracy. First, reassemble the pruned network and attach the classification layer from the original network using the supporting function reassembleTaylorNetwork.

prunedLayerGraph = reassembleTaylorNetwork(prunableNet, classes);

Specify training options.

miniBatchSize = 128;
validationFrequency = floor(numel(TTrain)/miniBatchSize);
options = trainingOptions("sgdm", ...
    InitialLearnRate=3e-4, ...
    MaxEpochs=15, ...
    MiniBatchSize=miniBatchSize, ...
    Shuffle="every-epoch", ...
    Plots="training-progress", ...
    Verbose=false, ...
    ValidationData={XValidation,TValidation}, ...
    OutputNetwork="best-validation-loss",...
    ValidationFrequency=validationFrequency);

Train the network.

trainedNetPruned = trainNetwork(XTrain,TTrain,prunedLayerGraph,options);

Evaluate Pruned Network

Calculate the accuracy of the pruned network after retraining.

YValidation = classify(trainedNetPruned,XValidation);
validationAccuracy = mean(YValidation == TValidation);
YTrain = classify(trainedNetPruned,XTrain);
trainAccuracy = mean(YTrain == TTrain);

disp(["Training Accuracy: " + trainAccuracy*100 + "%";"Validation Accuracy: " + validationAccuracy*100 + "%"])
    "Training Accuracy: 93.6128%"
    "Validation Accuracy: 92.3288%"

Plot the confusion matrix for the validation set.

figure(Units="normalized",Position=[0.2,0.2,0.5,0.5]);
cm = confusionchart(TValidation,YValidation, ...
    Title="Pruned Network", ...
    ColumnSummary="column-normalized",RowSummary="row-normalized");
sortClasses(cm,[commands,"unknown","background"])

Compare the accuracy of the original network and the pruned network. Because Taylor pruning removes several convolution filters, the accuracy decreases slightly.

You can interpret this decrease as trimming away predictive capacity that does not specifically help distinguish your chosen keywords from other inputs. This is the desired outcome of pruning: you reduce model size and complexity as much as possible while preserving the ability to perform the desired task.

Quantize Pruned Network

Create a dlquantizer object from the pruned network, and specify the ExecutionEnvironment property as "GPU" to prepare for deployment to a GPU device.

dlquantObj = dlquantizer(trainedNetPruned,ExecutionEnvironment="GPU");

Collect calibration statistics. Use the supporting function createCalibrationSet to create a representative calibration datastore with elements from each label in the training data.

calData = createCalibrationSet(XTrain, TTrain, 36, ["yes","no","up","down","left","right","on","off","stop","go","unknown","background"]);
calibrate(dlquantObj, calData);

Quantize the network with the quantize function.

qnetPruned = quantize(dlquantObj,ExponentScheme="Histogram");
save("qnet","qnetPruned")
qDetails = quantizationDetails(qnetPruned)
qDetails = struct with fields:
            IsQuantized: 1
          TargetLibrary: "cudnn"
    QuantizedLayerNames: [20×1 string]
    QuantizedLearnables: [10×3 table]
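To see which layers and learnable parameters were quantized, inspect the fields of the structure returned by quantizationDetails. The output depends on your network.

qDetails.QuantizedLayerNames
qDetails.QuantizedLearnables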

Evaluate Quantized Network

Calculate the accuracy of the quantized pruned network.

YValidation = classify(qnetPruned,XValidation);
validationAccuracy = mean(YValidation == TValidation);
YTrain = classify(qnetPruned,XTrain);
trainAccuracy = mean(YTrain == TTrain);

disp(["Training Accuracy: " + trainAccuracy*100 + "%";"Validation Accuracy: " + validationAccuracy*100 + "%"])
    "Training Accuracy: 93.5636%"
    "Validation Accuracy: 92.3288%"

Plot the confusion matrix for the validation set.

figure(Units="normalized",Position=[0.2,0.2,0.5,0.5]);
cm = confusionchart(TValidation,YValidation, ...
    Title="Pruned and Quantized Network", ...
    ColumnSummary="column-normalized",RowSummary="row-normalized");
sortClasses(cm,[commands,"unknown","background"])

Compare the accuracy of the pruned network before and after quantization. The overall network accuracy decreases only a small amount through the process of pruning and quantizing the network.

Evaluate Network Compression

Use the estimateNetworkMetrics function to generate network metrics for the original network, the pruned network, and the quantized network.

originalNetMetrics = estimateNetworkMetrics(trainedNet);
taylorNetMetrics = estimateNetworkMetrics(trainedNetPruned);
quantizedNetMetrics = estimateNetworkMetrics(qnetPruned);

Evaluate the impact of each stage of compression on the number of learnables in the network.

Extract the number of learnable parameters in each network and visualize the counts in a bar plot. This plot demonstrates the reduction in the number of learnables at each stage of model compression. The reduction comes from the filters removed during Taylor pruning; quantization does not change the number of learnables.

figure
learnables = [sum(originalNetMetrics.NumberOfLearnables)
               sum(taylorNetMetrics.NumberOfLearnables)
               sum(quantizedNetMetrics.NumberOfLearnables)];

x = categorical({'Original','Taylor Pruned','Quantized'});
x = reordercats(x, string(x));
plotResults(x, learnables)
ylabel("Number of Learnables")
title("Number of Learnables in Network")

Evaluate the impact of each stage of compression on the parameter memory of the network.

Extract the parameter memory of each network and visualize the values in a bar plot. Taylor pruning greatly reduces the parameter memory of the network, and quantization reduces it further. Together, Taylor pruning and quantization compress the deep learning network to meet reduced memory requirements while largely maintaining its accuracy.

figure;
memory = [sum(originalNetMetrics.("ParameterMemory (MB)"))
          sum(taylorNetMetrics.("ParameterMemory (MB)"))
          sum(quantizedNetMetrics.("ParameterMemory (MB)"))];

plotResults(x, memory)
ylabel("Parameter Memory (MB)")
title("Parameter Memory of Network")

Supporting Functions

Create Training and Validation Datastores

function [adsTrain, adsValidation] = createDatastores(dataset, commands)

    %Create training datastore
    ads = audioDatastore(fullfile(dataset,"train"), ...
    IncludeSubfolders=true, ...
    FileExtensions=".wav", ...
    LabelSource="foldernames");

    background = categorical("background");

    isCommand = ismember(ads.Labels,commands);
    isBackground = ismember(ads.Labels,background);
    isUnknown = ~(isCommand|isBackground);

    includeFraction = 0.2; % Fraction of unknowns to include.
    idx = find(isUnknown);
    idx = idx(randperm(numel(idx),round((1-includeFraction)*sum(isUnknown))));
    isUnknown(idx) = false;

    ads.Labels(isUnknown) = categorical("unknown");

    adsTrain = subset(ads,isCommand|isUnknown|isBackground);
    adsTrain.Labels = removecats(adsTrain.Labels);

    %Create validation datastore
    ads = audioDatastore(fullfile(dataset,"validation"), ...
    IncludeSubfolders=true, ...
    FileExtensions=".wav", ...
    LabelSource="foldernames");

    isCommand = ismember(ads.Labels,commands);
    isBackground = ismember(ads.Labels,background);
    isUnknown = ~(isCommand|isBackground);

    includeFraction = 0.2; % Fraction of unknowns to include.
    idx = find(isUnknown);
    idx = idx(randperm(numel(idx),round((1-includeFraction)*sum(isUnknown))));
    isUnknown(idx) = false;

    ads.Labels(isUnknown) = categorical("unknown");

    adsValidation = subset(ads,isCommand|isUnknown|isBackground);
    adsValidation.Labels = removecats(adsValidation.Labels);

end

Extract Features

function [XTrain, XValidation, TTrain, TValidation] = extractFeatures(adsTrain, adsValidation)

    fs = 16e3; % Known sample rate of the data set.

    segmentDuration = 1;
    frameDuration = 0.025;
    hopDuration = 0.010;

    FFTLength = 512;
    numBands = 50;

    segmentSamples = round(segmentDuration*fs);
    frameSamples = round(frameDuration*fs);
    hopSamples = round(hopDuration*fs);
    overlapSamples = frameSamples - hopSamples;

    %Create an audioFeatureExtractor object to perform the feature extraction.
    afe = audioFeatureExtractor( ...
        SampleRate=fs, ...
        FFTLength=FFTLength, ...
        Window=hann(frameSamples,"periodic"), ...
        OverlapLength=overlapSamples, ...
        barkSpectrum=true);
    setExtractorParameters(afe,"barkSpectrum",NumBands=numBands,WindowNormalization=false);

    % Pad the audio to a consistent length, extract the features, and then apply a logarithm.  
    transform1 = transform(adsTrain,@(x)[zeros(floor((segmentSamples-size(x,1))/2),1);x;zeros(ceil((segmentSamples-size(x,1))/2),1)]);
    transform2 = transform(transform1,@(x)extract(afe,x));
    transform3 = transform(transform2,@(x){log10(x+1e-6)});

    % Read all data from the datastore. The output is a numFiles-by-1 cell array. Each element corresponds to the auditory spectrogram extracted from a file.
    XTrain = readall(transform3);

    % Each element of XTrain is a numHops-by-numBands-by-numChannels
    % (98-by-50-by-1) auditory spectrogram.

    % Convert the cell array to a 4-dimensional array with auditory
    % spectrograms along the fourth dimension.
    XTrain = cat(4,XTrain{:});

    % Perform the feature extraction steps described above on the validation set.
    transform1 = transform(adsValidation,@(x)[zeros(floor((segmentSamples-size(x,1))/2),1);x;zeros(ceil((segmentSamples-size(x,1))/2),1)]);
    transform2 = transform(transform1,@(x)extract(afe,x));
    transform3 = transform(transform2,@(x){log10(x+1e-6)});
    XValidation = readall(transform3);
    XValidation = cat(4,XValidation{:});

    TTrain = adsTrain.Labels;
    TValidation = adsValidation.Labels;

end

Augment Dataset With Background Noise

function augmentDataset(datasetloc)
adsBkg = audioDatastore(fullfile(datasetloc,"background"));
fs = 16e3; % Known sample rate of the data set
segmentDuration = 1;
segmentSamples = round(segmentDuration*fs);

volumeRange = log10([1e-4,1]);

numBkgSegments = 4000;
numBkgFiles = numel(adsBkg.Files);
numSegmentsPerFile = floor(numBkgSegments/numBkgFiles);

fpTrain = fullfile(datasetloc,"train","background");
fpValidation = fullfile(datasetloc,"validation","background");

if ~datasetExists(fpTrain)
    
    % Create directories
    mkdir(fpTrain)
    mkdir(fpValidation)
    
    for backgroundFileIndex = 1:numel(adsBkg.Files)
        [bkgFile,fileInfo] = read(adsBkg);
        [~,fn] = fileparts(fileInfo.FileName);
        
        % Determine starting index of each segment
        segmentStart = randi(size(bkgFile,1)-segmentSamples,numSegmentsPerFile,1);
        
        % Determine gain of each clip
        gain = 10.^((volumeRange(2)-volumeRange(1))*rand(numSegmentsPerFile,1) + volumeRange(1));
        
        for segmentIdx = 1:numSegmentsPerFile
            
            % Isolate the randomly chosen segment of data.
            bkgSegment = bkgFile(segmentStart(segmentIdx):segmentStart(segmentIdx)+segmentSamples-1);
            
            % Scale the segment by the specified gain.
            bkgSegment = bkgSegment*gain(segmentIdx);
            
            % Clip the audio between -1 and 1.
            bkgSegment = max(min(bkgSegment,1),-1);
            
            % Create a file name.
            afn = fn + "_segment" + segmentIdx + ".wav";
            
            % Randomly assign background segment to either the train or
            % validation set.
            if rand > 0.85 % Assign 15% to validation
                dirToWriteTo = fpValidation;
            else % Assign 85% to train set.
                dirToWriteTo = fpTrain;
            end
            
            % Write the audio to the file location.
            ffn = fullfile(dirToWriteTo,afn);
            audiowrite(ffn,bkgSegment,fs)
            
        end
        
        % Print progress
        fprintf('Progress = %d (%%)\n',round(100*progress(adsBkg)))
        
    end
end
end

Mini-Batch Preprocessing Function

The preprocessMiniBatch function preprocesses a mini-batch of predictors and labels for loss computation during training.

function [X, Y] = preprocessMiniBatch(XCell, YCell)
    % Concatenate.
    X = cat(4,XCell{:});

    % Extract label data from cell and concatenate.
    Y = cat(2,YCell{:});

    % One-hot encode labels.
    Y = onehotencode(Y,1);
end

Taylor Pruning Loop Function

function prunableNet = taylorPruningLoop(prunableNet, mbqTrain, mbqValidation, classes, ...
    numTest, maxPruningIterations, maxPrunableFilters, maxToPrune, minPrunables, learnRate, ...
    momentum, numMinibatchUpdates, validationFrequency,trainAccuracy)
    %Initialize plots used and perform Taylor pruning with custom loop

    accuracyOfOriginalNet = trainAccuracy*100;

    %Initialize Progress Plots
    figure("Position",[10,10,700,700])
    tl = tiledlayout(3,1);
    lossAx = nexttile;
    lineLossFinetune = animatedline(Color=[0.85 0.325 0.098]);
    ylim([0 inf])
    xlabel("Fine-Tuning Iteration")
    ylabel("Loss")
    grid on
    title("Mini-Batch Loss During Pruning")
    xTickPos = [];

    accuracyAx = nexttile;
    lineAccuracyPruning = animatedline(Color=[0.098 0.325 0.85],LineWidth=2,Marker="o");
    ylim([50 100])
    xlabel("Pruning Iteration")
    ylabel("Accuracy")
    grid on
    addpoints(lineAccuracyPruning,0,accuracyOfOriginalNet)
    title("Validation Accuracy After Pruning")

    numPrunablesAx = nexttile;
    lineNumPrunables = animatedline(Color=[0.4660 0.6740 0.1880],LineWidth=2,Marker="^");
    ylim([0 maxPrunableFilters])
    xlabel("Pruning Iteration")
    ylabel("Prunable Filters")
    grid on
    addpoints(lineNumPrunables,0,double(maxPrunableFilters))
    title("Number of Prunable Convolution Filters After Pruning")

    start = tic;
    iteration = 0;

    for pruningIteration = 1:maxPruningIterations

        % Shuffle data.
        shuffle(mbqTrain);

        % Reset the velocity parameter for the SGDM solver in every pruning
        % iteration.
        velocity = [];

        % Loop over mini-batches.
        fineTuningIteration = 0;
        while hasdata(mbqTrain)
            iteration = iteration + 1;
            fineTuningIteration = fineTuningIteration + 1;

            % Read mini-batch of data.
            [X, T] = next(mbqTrain);

            % Evaluate the pruning activations, gradients of the pruning
            % activations, model gradients, state, and loss using the dlfeval and
            % modelLossPruning functions.
            [loss,pruningActivations, pruningGradients, netGradients, state] = ...
                dlfeval(@modelLossPruning, prunableNet, X, T);

            % Update the network state.
            prunableNet.State = state;

            % Update the network parameters using the SGDM optimizer.
            [prunableNet, velocity] = sgdmupdate(prunableNet, netGradients, velocity, learnRate, momentum);

            % Compute first-order Taylor scores and accumulate the score across
            % previous mini-batches of data.
            prunableNet = updateScore(prunableNet, pruningActivations, pruningGradients);

            % Display the training progress.
            D = duration(0,0,toc(start),Format="hh:mm:ss");
            addpoints(lineLossFinetune, iteration, double(loss))
            title(tl,"Processing Pruning Iteration: " + pruningIteration + " of " + maxPruningIterations + ...
                ", Elapsed Time: " + string(D))
            % Synchronize the x-axis of the accuracy and numPrunables plots with the loss plot.
            xlim(accuracyAx,lossAx.XLim)
            xlim(numPrunablesAx,lossAx.XLim)
            drawnow

            % Stop the fine-tuning loop when numMinibatchUpdates is reached.
            if (fineTuningIteration > numMinibatchUpdates)
                break
            end
        end

        % Prune filters based on previously computed Taylor scores.
        prunableNet = updatePrunables(prunableNet, MaxToPrune = maxToPrune);

        % Show results on the validation data set in a subset of pruning iterations.
        isLastPruningIteration = pruningIteration == maxPruningIterations;
        if (mod(pruningIteration, validationFrequency) == 0 || isLastPruningIteration)
            accuracy = modelAccuracy(prunableNet, mbqValidation, classes, numTest);
            addpoints(lineAccuracyPruning, iteration, accuracy)
            addpoints(lineNumPrunables,iteration,double(prunableNet.NumPrunables))
        end
    
        % Set x-axis tick values at the end of each pruning iteration.
        xTickPos = [xTickPos, iteration]; %#ok<AGROW>
        xticks(lossAx,xTickPos)
        xticks(accuracyAx,[0,xTickPos])
        xticks(numPrunablesAx,[0,xTickPos])
        xticklabels(accuracyAx,["Unpruned",string(1:pruningIteration)])
        xticklabels(numPrunablesAx,["Unpruned",string(1:pruningIteration)])
        drawnow

        % Break if number of prunables is less than parameter
        if (prunableNet.NumPrunables < minPrunables)
            break
        end

    end
end

Model Loss Pruning Function

The modelLossPruning function is called within the Taylor pruning loop. It performs a forward pass and returns the loss, the pruning activations, the gradients of the loss with respect to the pruning activations, the model gradients, and the network state.

function [loss,pruningActivations,pruningGradients,netGradients,state] = modelLossPruning(prunableNet, X, T)

    % Forward pass through the prunable network.
    [pred,state,pruningActivations] = forward(prunableNet,X);

    % Compute the cross-entropy loss.
    loss = crossentropy(pred,T);

    % Compute the gradients of the loss with respect to the pruning
    % activations and the learnable parameters.
    [pruningGradients,netGradients] = dlgradient(loss,pruningActivations,prunableNet.Learnables);
end

Model Accuracy Function

The modelAccuracy function is called within the Taylor pruning loop and computes the accuracy of the network on the minibatchqueue object mbq.

function accuracy = modelAccuracy(net, mbq, classes, numObservations)

    totalCorrect = 0;
    reset(mbq);

    while hasdata(mbq)
        [dlX, Y] = next(mbq);

        dlYPred = extractdata(predict(net, dlX));

        YPred = onehotdecode(dlYPred,classes,1)';
        YReal = onehotdecode(Y,classes,1)';

        miniBatchCorrect = nnz(YPred == YReal);

        totalCorrect = totalCorrect + miniBatchCorrect;
    end

    accuracy = totalCorrect / numObservations * 100;
end

Reassemble Pruned Network

Convert the taylorPrunableNetwork object to a layer graph and reattach the classification layer.

function prunedLayerGraph = reassembleTaylorNetwork(prunableNet, classes)

    prunedNet = dlnetwork(prunableNet);
    prunedLayerGraph = layerGraph(prunedNet);

    % Add a classification layer with the classes defined in the training data.
    lgraphUpdated = addLayers(prunedLayerGraph, classificationLayer(Classes=classes));
    prunedLayerGraph = connectLayers(lgraphUpdated,prunedNet.OutputNames{1},string(lgraphUpdated.OutputNames{1}));

end

Create Calibration Data Set

Create a calibration data set containing n examples of each label from the training data.

function XCalibration = createCalibrationSet(XTrain, TTrain, n, labels)
XCalibration = [];
for i = 1:numel(labels)
    % Find the logical index of the label.
    idx = (TTrain == labels(i));
    % Create a subset of the data corresponding to the logical indices.
    label_subset = XTrain(:,:,:,idx);
    first_n_labels = label_subset(:,:,:,1:n);
    % Concatenate along the fourth dimension.
    XCalibration = cat(4, XCalibration, first_n_labels);
end
end

Plot Results Function

Create the bar plots used to evaluate network compression.

function plotResults(x, data)
    b = bar(x, data);
    b.FaceColor = 'flat';
    b.CData(1, :) = [0 0.9 1];
    b.CData(2, :) = [0 0.8 0.8];
    b.CData(3, :) = [0.8 0 0.8];
end

References

[1] Warden, Pete. "Speech Commands: A public dataset for single-word speech recognition", 2017. Available from https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.01.tar.gz. Copyright Google 2017. The Speech Commands Dataset is licensed under the Creative Commons Attribution 4.0 license, available here: https://creativecommons.org/licenses/by/4.0/legalcode.