Sequential Feature Selection for Speech Emotion Recognition

This example shows a typical workflow for feature selection applied to the task of speech emotion recognition. You begin by establishing a baseline accuracy using a common set of audio features (mel frequency cepstral coefficients, MFCC). You then augment the data set to reduce overfitting. Finally, you perform sequential feature selection to choose a better feature set.

In sequential feature selection, you train a network on a given feature set and then incrementally add or remove features until the highest accuracy is reached [1]. In this example, you apply sequential forward selection to the task of speech emotion recognition using the Berlin Database of Emotional Speech [2].

Download Data Set

The Berlin Database of Emotional Speech contains 535 utterances spoken by 10 actors intended to convey one of the following emotions: anger, boredom, disgust, anxiety/fear, happiness, sadness, or neutral. The emotions are text independent. Download the database from http://emodb.bilderbar.info/index-1280.html and then set PathToDatabase to the location of the audio files. Create an audioDatastore that points to the audio files.

datafolder = PathToDatabase;
ads = audioDatastore(fullfile(datafolder,"wav"));

The file names are codes indicating the speaker ID, text spoken, emotion, and version. The website contains a key for interpreting the code and additional information about the speakers such as gender and age. Create a table with the variables Speaker and Emotion. Decode the file names into the table.

filepaths = ads.Files;
emotionCodes = cellfun(@(x)x(end-5),filepaths,'UniformOutput',false);
emotions = replace(emotionCodes,{'W','L','E','A','F','T','N'}, ...
    {'Anger','Boredom','Disgust','Anxiety/Fear','Happiness','Sadness','Neutral'});

speakerCodes = cellfun(@(x)x(end-10:end-9),filepaths,'UniformOutput',false);
labelTable = cell2table([speakerCodes,emotions],'VariableNames',{'Speaker','Emotion'});
labelTable.Emotion = categorical(labelTable.Emotion);
labelTable.Speaker = categorical(labelTable.Speaker);
head(labelTable)
ans =

  8×2 table

    Speaker     Emotion 
    _______    _________

      03       Happiness
      03       Neutral  
      03       Anger    
      03       Happiness
      03       Neutral  
      03       Sadness  
      03       Anger    
      03       Anger    

labelTable is in the same order as the files in audioDatastore. Set the Labels property of the audioDatastore to the labelTable.

ads.Labels = labelTable;

You can now split by label and subset to isolate portions of the data. Create a subset of the datastore that contains only utterances of speaker 12 conveying boredom. Listen to the file and view the time-domain waveform.

speaker = categorical("12");
emotion = categorical("Boredom");
adsSubset = subset(ads,ads.Labels.Speaker==speaker & ads.Labels.Emotion == emotion);

[audio,adsInfo] = read(adsSubset);
fs = adsInfo.SampleRate;
sound(audio,fs)

t = (0:size(audio,1)-1)/fs;
figure
plot(t,audio)
grid on
xlabel('Time (s)')
ylabel('Amplitude')

To provide an accurate assessment of the model you create in this example, train and validate using leave-one-speaker-out (LOSO) k-fold cross validation. In this method, you train using k-1 speakers and then validate on the left-out speaker. You repeat this procedure k times. The final validation accuracy is the average of the k folds.
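
A minimal sketch of the LOSO pattern follows, for orientation only; the full per-fold workflow is implemented in HelperTrainAndValidateNetwork in the appendix. Here, trainAndScore is a hypothetical stand-in for the training and evaluation steps of one fold.

speakers = categories(ads.Labels.Speaker);
foldAccuracy = zeros(numel(speakers),1);
for ii = 1:numel(speakers)
    % Train on all speakers except speaker ii, then validate on speaker ii.
    adsTrainFold = subset(ads,ads.Labels.Speaker ~= speakers{ii});
    adsValFold   = subset(ads,ads.Labels.Speaker == speakers{ii});
    foldAccuracy(ii) = trainAndScore(adsTrainFold,adsValFold); % hypothetical helper
end
meanAccuracy = mean(foldAccuracy);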

Create a variable that contains the speaker IDs. Determine the number of folds: 1 for each speaker. The database contains utterances from 10 unique speakers. Use summary to display the speaker IDs (left column) and the number of utterances they contribute to the database (right column).

speaker = ads.Labels.Speaker;
numFolds = numel(categories(speaker));
summary(speaker)
     03      49 
     08      58 
     09      43 
     10      38 
     11      55 
     12      35 
     13      61 
     14      69 
     15      56 
     16      71 

Generate Baseline Validation Accuracy

As a first step to developing a machine learning model, determine a baseline accuracy.

Assume that a first training attempt achieves a 10-fold cross-validation accuracy of only about 60% because of insufficient training data, with the model overfitting some folds and underfitting others. To improve the overall fit, increase the size of the data set by a factor of 50 using audioDataAugmenter.

Create an audioDataAugmenter object. Set the probability of applying pitch shifting to 0.5 and use the default range. Set the probability of applying time shifting to 1 and use a range of [-0.3,0.3] seconds. Set the probability of adding noise to 1 and specify the SNR range as [-20,40] dB.

augmenter = audioDataAugmenter('NumAugmentations',50, ...
    'TimeStretchProbability',0, ...
    'VolumeControlProbability',0, ...
    ...
    'PitchShiftProbability',0.5, ...
    ...
    'TimeShiftProbability',1, ...
    'TimeShiftRange',[-0.3,0.3], ...
    ...
    'AddNoiseProbability',1, ...
    'SNRRange', [-20,40]);

Create a new folder in your current folder to hold the augmented data set.

currentDir = pwd;
writeDirectory = fullfile(currentDir,'augmentedData');
mkdir(writeDirectory)

For each file in the audio datastore:

  1. Create 50 augmentations.

  2. Normalize the audio to have a max absolute value of 1.

  3. Write the augmented audio data as a WAV file. Append _augK to each of the file names, where K is the augmentation number. To speed up processing, use parfor and partition the datastore.

reset(ads)
numPartitions = 6;
parfor ii = 1:numPartitions
    adsPart = partition(ads,numPartitions,ii);
    while hasdata(adsPart)
        [x,adsInfo] = read(adsPart);
        data = augment(augmenter,x,fs);

        [~,fn] = fileparts(adsInfo.FileName);
        for i = 1:size(data,1)
            augmentedAudio = data.Audio{i};
            augmentedAudio = augmentedAudio/max(abs(augmentedAudio),[],'all');
            iString = sprintf('%02d',i); % zero-pad so files sort in augmentation order
            audiowrite(fullfile(writeDirectory,sprintf('%s_aug%s.wav',fn,iString)), ...
                augmentedAudio,fs);
        end
    end
end

Create an audio datastore that points to the augmented data set. Replicate the rows of the label table of the original datastore NumAugmentations times to determine the labels of the augmented datastore.

augads = audioDatastore(writeDirectory);
augads.Labels = repelem(ads.Labels,augmenter.NumAugmentations,1);
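
As a quick consistency check (illustrative, not required by the workflow), confirm that the replicated label table stays aligned with the augmented files:

assert(height(augads.Labels) == numel(augads.Files))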

Mel-frequency cepstral coefficients (MFCC), delta-MFCC, and delta-delta MFCC are popular features for audio. As a baseline, use MFCC, delta-MFCC, and delta-delta MFCC with 30 ms windows and no overlap.

Create an audioFeatureExtractor object. Set Window to a periodic 30 ms Hamming window, OverlapLength to 0, and SampleRate to the sample rate of the database. Set mfcc, mfccDelta, and mfccDeltaDelta to true to extract them.

win = hamming(round(0.03*fs),"periodic");
overlapLength = 0;

extractor = audioFeatureExtractor( ...
    'Window',win, ...
    'OverlapLength',overlapLength, ...
    'SampleRate',fs, ...
    ...
    'mfcc',true, ...
    'mfccDelta',true, ...
    'mfccDeltaDelta',true)
extractor = 

  audioFeatureExtractor with properties:

   Properties
                     Window: [480×1 double]
              OverlapLength: 0
                 SampleRate: 16000
                  FFTLength: []
    SpectralDescriptorInput: 'linearSpectrum'

   Enabled Features
     mfcc, mfccDelta, mfccDeltaDelta

   Disabled Features
     linearSpectrum, melSpectrum, barkSpectrum, erbSpectrum, gtcc, gtccDelta
     gtccDeltaDelta, spectralCentroid, spectralCrest, spectralDecrease, spectralEntropy, spectralFlatness
     spectralFlux, spectralKurtosis, spectralRolloffPoint, spectralSkewness, spectralSlope, spectralSpread
     pitch, harmonicRatio


   To extract a feature, set the corresponding property to true.
   For example, obj.mfcc = true, adds mfcc to the list of enabled features.
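
As an optional sanity check (illustrative), you can apply the extractor to the clip read earlier. Each row of the output corresponds to one analysis window; assuming the default of 13 coefficients per cepstral feature, the three enabled features yield 39 columns.

features = extract(extractor,audio); % one row per 30 ms window
size(features)                       % numWindows-by-39, assuming the default 13 coefficients per feature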

The steps for each fold follow:

  1. Divide the audio datastore into training and validation sets.

  2. Extract feature vectors from the training and validation sets.

  3. Normalize the feature vectors.

  4. Buffer the feature vectors into sequences of 20 with overlaps of 10.

  5. Replicate the labels so that they are in one-to-one correspondence with the feature vectors.

  6. Define the network.

  7. Define training options.

  8. Train the network.

  9. Evaluate the network.

1. Divide the audio datastore into training and validation sets. For the training set, use the augmented datastore excluding the first speaker. For the validation set, use the original recordings of the first speaker only. Convert the data to tall arrays.

adsTrain = subset(augads,augads.Labels.Speaker~=speaker(1));
adsTrain.Labels = adsTrain.Labels.Emotion;
tallTrain = tall(adsTrain);

adsValidation = subset(ads,ads.Labels.Speaker==speaker(1));
adsValidation.Labels = adsValidation.Labels.Emotion;
tallValidation = tall(adsValidation);
Starting parallel pool (parpool) using the 'local' profile ...
Connected to the parallel pool (number of workers: 6).

2. Extract features from the training and validation sets. Reorient the features so that each column corresponds to one analysis window, the numFeatures-by-numTimeSteps layout that sequenceInputLayer expects.

featuresTallTrain = cellfun(@(x)extract(extractor,x),tallTrain,"UniformOutput",false);
featuresTallTrain = cellfun(@(x)x',featuresTallTrain,"UniformOutput",false);
featuresTrain     = gather(featuresTallTrain);

featuresTallValidation = cellfun(@(x)extract(extractor,x),tallValidation,"UniformOutput",false);
featuresTallValidation = cellfun(@(x)x',featuresTallValidation,"UniformOutput",false);
featuresValidation = gather(featuresTallValidation);
Evaluating tall expression using the Parallel Pool 'local':
- Pass 1 of 1: Completed in 2 min 18 sec
Evaluation completed in 2 min 18 sec
Evaluating tall expression using the Parallel Pool 'local':
- Pass 1 of 1: Completed in 1.2 sec
Evaluation completed in 1.2 sec

3. Use the training set to determine the mean and standard deviation of each feature. Normalize the training and validation sets.

allFeatures = cat(2,featuresTrain{:});
M = mean(allFeatures,2,'omitnan');
S = std(allFeatures,0,2,'omitnan');

featuresTrain = cellfun(@(x)(x-M)./S,featuresTrain,'UniformOutput',false);
featuresValidation = cellfun(@(x)(x-M)./S,featuresValidation,'UniformOutput',false);

4. Buffer the feature vectors into sequences so that each sequence consists of 20 feature vectors with overlaps of 10 feature vectors.

featureVectorsPerSequence = 20;
featureVectorOverlap = 10;
[sequencesTrain,sequencePerFileTrain] = HelperFeatureVector2Sequence(featuresTrain,featureVectorsPerSequence,featureVectorOverlap);
[sequencesValidation,sequencePerFileValidation] = HelperFeatureVector2Sequence(featuresValidation,featureVectorsPerSequence,featureVectorOverlap);
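
For reference, the number of sequences each file yields follows from the hop between sequence starts, 20 - 10 = 10 feature vectors. A short illustrative check of the arithmetic (HelperFeatureVector2Sequence in the appendix implements the actual buffering):

hop = featureVectorsPerSequence - featureVectorOverlap;
numSequencesFor = @(numVectors)floor((numVectors - featureVectorsPerSequence)/hop) + 1;
numSequencesFor(75) % a file yielding 75 feature vectors produces 6 sequences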

5. Replicate the labels of the training and validation sets so that they are in one-to-one correspondence with the sequences. Not all speakers have utterances for all emotions. Create an empty categorical array that contains all the emotion categories and concatenate it with the validation labels so that the resulting categorical array includes every emotion.

labelsTrain = repelem(adsTrain.Labels,[sequencePerFileTrain{:}]);

emptyEmotions = ads.Labels.Emotion;
emptyEmotions(:) = [];
labelsValidation = [emptyEmotions;adsValidation.Labels];
labelsValidation = repelem(labelsValidation,[sequencePerFileValidation{:}]);
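
Optionally, verify the one-to-one correspondence between sequences and labels (an illustrative check):

assert(numel(labelsTrain) == numel(sequencesTrain) && ...
    numel(labelsValidation) == numel(sequencesValidation))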

6. Define a BiLSTM network using bilstmLayer. Place a dropoutLayer before and after the bilstmLayer to help prevent overfitting.

dropoutProb1 = 0.3;
numUnits = 200;
dropoutProb2 = 0.6;
layers = [ ...
    sequenceInputLayer(size(sequencesTrain{1},1))
    dropoutLayer(dropoutProb1)
    bilstmLayer(numUnits,"OutputMode","last")
    dropoutLayer(dropoutProb2)
    fullyConnectedLayer(numel(categories(emptyEmotions)))
    softmaxLayer
    classificationLayer];

7. Define training options using trainingOptions.

miniBatchSize = 512;
initialLearnRate = 0.005;
learnRateDropPeriod = 2;
maxEpochs = 3;
options = trainingOptions("adam", ...
    "MiniBatchSize",miniBatchSize, ...
    "InitialLearnRate",initialLearnRate, ...
    "LearnRateDropPeriod",learnRateDropPeriod, ...
    "LearnRateSchedule","piecewise", ...
    "MaxEpochs",maxEpochs, ...
    "Shuffle","every-epoch", ...
    "ValidationData",{sequencesValidation,labelsValidation}, ...
    "Verbose",false, ...
    "Plots","Training-Progress");

8. Train the network using trainNetwork.

net = trainNetwork(sequencesTrain,labelsTrain,layers,options);

9. Evaluate the network. Call classify to get the predicted label of each sequence. Take the mode of the per-sequence predictions belonging to each file to get the predicted label of each file. Plot the confusion chart of the true and predicted labels.

predictedLabelsPerSequence = classify(net,sequencesValidation);

labelsTrue = adsValidation.Labels;
labelsPred = labelsTrue;
idx = 1;
for ii = 1:numel(labelsTrue)
    labelsPred(ii,:) = mode(predictedLabelsPerSequence(idx:idx + sequencePerFileValidation{ii} - 1,:),1);
    idx = idx + sequencePerFileValidation{ii};
end

figure
cm = confusionchart(labelsTrue,labelsPred);
valAccuracy = mean(labelsTrue==labelsPred)*100;
cm.Title = sprintf('Confusion Matrix for Fold 1\nAccuracy = %0.1f',valAccuracy);
sortClasses(cm,categories(emptyEmotions))
cm.ColumnSummary = 'column-normalized';
cm.RowSummary = 'row-normalized';

The helper function HelperTrainAndValidateNetwork performs the steps outlined above for all 10 folds and returns the true and predicted labels for each fold. Call HelperTrainAndValidateNetwork with the audioDatastore, the augmented audioDatastore, and the audioFeatureExtractor.

[labelsTrue,labelsPred] = HelperTrainAndValidateNetwork(ads,augads,extractor);

Print the accuracy per fold and plot the 10-fold confusion chart.

for ii = 1:numel(labelsTrue)
    foldAcc = mean(labelsTrue{ii}==labelsPred{ii})*100;
    fprintf('Fold %1.0f, Accuracy = %0.1f\n',ii,foldAcc);
end

labelsTrueMat = cat(1,labelsTrue{:});
labelsPredMat = cat(1,labelsPred{:});
figure
cm = confusionchart(labelsTrueMat,labelsPredMat);
valAccuracy = mean(labelsTrueMat==labelsPredMat)*100;
cm.Title = sprintf('Confusion Matrix for 10-Fold Cross-Validation\nAverage Accuracy = %0.1f',valAccuracy);
sortClasses(cm,categories(emptyEmotions))
cm.ColumnSummary = 'column-normalized';
cm.RowSummary = 'row-normalized';
Fold 1, Accuracy = 87.8
Fold 2, Accuracy = 86.2
Fold 3, Accuracy = 72.1
Fold 4, Accuracy = 84.2
Fold 5, Accuracy = 76.4
Fold 6, Accuracy = 65.7
Fold 7, Accuracy = 68.9
Fold 8, Accuracy = 85.5
Fold 9, Accuracy = 75.0
Fold 10, Accuracy = 64.8

Sequential Feature Selection

Next, try to further improve accuracy by choosing a better feature set. Sequential feature selection can be time consuming. To reduce the selection time, reduce the augmented data set so that it contains only 10 augmentations for each original file, and use this reduced set to select features. Once the best feature set is chosen, train on the full augmented data set (50 augmentations per file) for a final evaluation.

augads10 = subset(augads,1:5:numel(augads.Files));
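
Because each original file contributes 50 consecutive, zero-padded augmentations to the datastore, taking every fifth file keeps exactly 10 augmentations per original file. A quick illustrative check:

numel(augads10.Files)/numel(ads.Files) % should equal 10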

Create a new audioFeatureExtractor object. Use the same window and overlap length as previously. Set all features you want to test to true.

extractor = audioFeatureExtractor( ...
    'Window',       win, ...
    'OverlapLength',overlapLength, ...
    'SampleRate',   fs, ...
    ...
    'linearSpectrum',      false, ...
    'melSpectrum',         false, ...
    'barkSpectrum',        false, ...
    'erbSpectrum',         false, ...
    ...
    'mfcc',                true, ...
    'mfccDelta',           true, ...
    'mfccDeltaDelta',      true, ...
    'gtcc',                true, ...
    'gtccDelta',           true, ...
    'gtccDeltaDelta',      true, ...
    ...
    'SpectralDescriptorInput','melSpectrum', ...
    'spectralCentroid',    true, ...
    'spectralCrest',       true, ...
    'spectralDecrease',    true, ...
    'spectralEntropy',     true, ...
    'spectralFlatness',    true, ...
    'spectralFlux',        true, ...
    'spectralKurtosis',    true, ...
    'spectralRolloffPoint',true, ...
    'spectralSkewness',    true, ...
    'spectralSlope',       true, ...
    'spectralSpread',      true, ...
    ...
    'pitch',               true, ...
    'harmonicRatio',       true);

Forward Selection

In forward selection, you start by evaluating each feature separately. Once the best single feature is determined, you hold it and evaluate each pairing of it with a remaining feature. Once the best pair is determined, you hold both and evaluate triples, and so on. When accuracy stops improving, the feature selection ends.
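
In outline, forward selection looks like the following sketch, written as a standalone function for clarity. This is a simplified illustration, not the implementation used in this example; HelperSFS in the appendix is the full version. Here, scoreFcn is a hypothetical function handle that returns the cross-validation accuracy for a cell array of feature names.

function selected = forwardSelect(allFeatures,scoreFcn)
    % Greedy forward selection: grow the feature set one feature at a
    % time, keep the addition that most improves the score, and stop
    % when no addition improves it.
    selected = {};
    bestScore = -inf;
    improved = true;
    while improved && numel(selected) < numel(allFeatures)
        improved = false;
        candidates = reshape(setdiff(allFeatures,selected),1,[]); % force row orientation
        for ii = 1:numel(candidates)
            trialScore = scoreFcn([selected,candidates(ii)]); % scoreFcn is hypothetical
            if trialScore > bestScore
                bestScore = trialScore;
                bestCandidate = candidates(ii);
                improved = true;
            end
        end
        if improved
            selected = [selected,bestCandidate]; %#ok<AGROW>
        end
    end
end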

[logbook,bestFeatures] = ...
    HelperSFS(ads,augads10,extractor,'forward');

Inspect the top and bottom feature configurations.

head(logbook)
tail(logbook)
ans =

  8×2 table

                                Features                                 Accuracy
    _________________________________________________________________    ________

    "mfccDelta, gtcc, gtccDelta, spectralCrest"                           75.327 
    "mfccDelta, gtcc, gtccDelta, gtccDeltaDelta, spectralCrest"           74.393 
    "mfccDelta, gtcc, gtccDelta, spectralDecrease"                        74.019 
    "mfccDelta, gtcc, gtccDelta, spectralCrest, spectralRolloffPoint"     74.019 
    "mfccDelta, gtcc, gtccDelta, spectralCrest, harmonicRatio"            74.019 
    "mfccDelta, gtcc, gtccDelta"                                          73.458 
    "mfccDelta, gtcc, gtccDelta, spectralCentroid"                        73.458 
    "mfcc, mfccDelta, gtcc, gtccDelta"                                    73.271 


ans =

  8×2 table

           Features           Accuracy
    ______________________    ________

    "pitch"                    31.963 
    "spectralFlux"             31.589 
    "spectralEntropy"           27.85 
    "spectralRolloffPoint"     25.794 
    "spectralCrest"            24.486 
    "spectralDecrease"         24.486 
    "harmonicRatio"            23.364 
    "spectralFlatness"         21.121 

Test Selected Features on Augmented Data Set

Set the best feature configuration, as determined by sequential feature selection, on the audioFeatureExtractor object.

set(extractor,bestFeatures)

Test the LOSO 10-fold cross validation accuracy of the selected feature set using the full augmented data set.

[labelsTrue,labelsPred] = HelperTrainAndValidateNetwork(ads,augads,extractor);

labelsTrueMat = cat(1,labelsTrue{:});
labelsPredMat = cat(1,labelsPred{:});
figure
cm = confusionchart(labelsTrueMat,labelsPredMat);
valAccuracy = mean(labelsTrueMat==labelsPredMat)*100;
cm.Title = sprintf('Confusion Matrix for 10-Fold Cross-Validation\nAverage Accuracy = %0.1f',valAccuracy);
sortClasses(cm,categories(emptyEmotions))
cm.ColumnSummary = 'column-normalized';
cm.RowSummary = 'row-normalized';

References

[1] Jain, A., and D. Zongker. "Feature Selection: Evaluation, Application, and Small Sample Performance." IEEE Transactions on Pattern Analysis and Machine Intelligence. Vol. 19, Issue 2, 1997, pp. 153-158.

[2] Burkhardt, F., A. Paeschke, M. Rolfes, W.F. Sendlmeier, and B. Weiss, "A Database of German Emotional Speech." In Proceedings Interspeech 2005. Lisbon, Portugal: International Speech Communication Association, 2005.

Appendix -- Supporting Functions

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [sequences,sequencePerFile] = HelperFeatureVector2Sequence(features,featureVectorsPerSequence,featureVectorOverlap)
    % Copyright 2019 The MathWorks, Inc.
    if featureVectorsPerSequence <= featureVectorOverlap
        error('The number of overlapping feature vectors must be less than the number of feature vectors per sequence.')
    end

    hopLength = featureVectorsPerSequence - featureVectorOverlap;
    idx1 = 1;
    sequences = {};
    sequencePerFile = cell(numel(features),1);
    for ii = 1:numel(features)
        sequencePerFile{ii} = floor((size(features{ii},2) - featureVectorsPerSequence)/hopLength) + 1;
        idx2 = 1;
        for j = 1:sequencePerFile{ii}
            sequences{idx1,1} = features{ii}(:,idx2:idx2 + featureVectorsPerSequence - 1); %#ok<AGROW>
            idx1 = idx1 + 1;
            idx2 = idx2 + hopLength;
        end
    end
end
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [trueLabelsCrossFold,predictedLabelsCrossFold] = HelperTrainAndValidateNetwork(varargin)
    % Copyright 2019 The MathWorks, Inc.
    if nargin == 3
        ads = varargin{1};
        augads = varargin{2};
        extractor = varargin{3};
    elseif nargin == 2
        ads = varargin{1};
        augads = varargin{1};
        extractor = varargin{2};
    end
    speaker = categories(ads.Labels.Speaker);
    numFolds = numel(speaker);
    emptyEmotions = categorical(ads.Labels.Emotion);
    emptyEmotions(:) = [];

    % Loop over each fold
    trueLabelsCrossFold = {};
    predictedLabelsCrossFold = {};
    
    for i = 1:numFolds
        
        % 1. Divide the audio datastore into training and validation sets.
        % Convert the data to tall arrays.
        idxTrain           = augads.Labels.Speaker~=speaker(i);
        augadsTrain        = subset(augads,idxTrain);
        augadsTrain.Labels = augadsTrain.Labels.Emotion;
        tallTrain          = tall(augadsTrain);
        idxValidation        = ads.Labels.Speaker==speaker(i);
        adsValidation        = subset(ads,idxValidation);
        adsValidation.Labels = adsValidation.Labels.Emotion;
        tallValidation       = tall(adsValidation);

        % 2. Extract features from the training and validation sets.
        % Reorient the features so that each column corresponds to one
        % analysis window, as sequenceInputLayer expects.
        tallTrain         = cellfun(@(x)x/max(abs(x),[],'all'),tallTrain,"UniformOutput",false);
        tallFeaturesTrain = cellfun(@(x)extract(extractor,x),tallTrain,"UniformOutput",false);
        tallFeaturesTrain = cellfun(@(x)x',tallFeaturesTrain,"UniformOutput",false);  %#ok<NASGU>
        [~,featuresTrain] = evalc('gather(tallFeaturesTrain)'); % Use evalc to suppress command-line output.
        tallValidation         = cellfun(@(x)x/max(abs(x),[],'all'),tallValidation,"UniformOutput",false);
        tallFeaturesValidation = cellfun(@(x)extract(extractor,x),tallValidation,"UniformOutput",false);
        tallFeaturesValidation = cellfun(@(x)x',tallFeaturesValidation,"UniformOutput",false); %#ok<NASGU>
        [~,featuresValidation] = evalc('gather(tallFeaturesValidation)'); % Use evalc to suppress command-line output.

        % 3. Use the training set to determine the mean and standard
        % deviation of each feature. Normalize the training and validation
        % sets.
        allFeatures = cat(2,featuresTrain{:});
        M = mean(allFeatures,2,'omitnan');
        S = std(allFeatures,0,2,'omitnan');
        featuresTrain = cellfun(@(x)(x-M)./S,featuresTrain,'UniformOutput',false);
        for ii = 1:numel(featuresTrain)
            idx = find(isnan(featuresTrain{ii}));
            if ~isempty(idx)
                featuresTrain{ii}(idx) = 0;
            end
        end
        featuresValidation = cellfun(@(x)(x-M)./S,featuresValidation,'UniformOutput',false);
        for ii = 1:numel(featuresValidation)
            idx = find(isnan(featuresValidation{ii}));
            if ~isempty(idx)
                featuresValidation{ii}(idx) = 0;
            end
        end

        % 4. Buffer the feature vectors into sequences so that each
        % sequence consists of 20 feature vectors with overlaps of 10.
        featureVectorsPerSequence = 20;
        featureVectorOverlap = 10;
        [sequencesTrain,sequencePerFileTrain] = HelperFeatureVector2Sequence(featuresTrain,featureVectorsPerSequence,featureVectorOverlap);
        [sequencesValidation,sequencePerFileValidation] = HelperFeatureVector2Sequence(featuresValidation,featureVectorsPerSequence,featureVectorOverlap);

        % 5. Replicate the labels of the training set so that they are in
        % one-to-one correspondence with the sequences.
        labelsTrain = [emptyEmotions;augadsTrain.Labels];
        labelsTrain = labelsTrain(:);
        labelsTrain = repelem(labelsTrain,[sequencePerFileTrain{:}]);

        % 6. Define a BiLSTM network.
        dropoutProb1 = 0.3;
        numUnits     = 200;
        dropoutProb2 = 0.6;
        layers = [ ...
            sequenceInputLayer(size(sequencesTrain{1},1))
            dropoutLayer(dropoutProb1)
            bilstmLayer(numUnits,"OutputMode","last")
            dropoutLayer(dropoutProb2)
            fullyConnectedLayer(numel(categories(emptyEmotions)))
            softmaxLayer
            classificationLayer];

        % 7. Define training options.
        miniBatchSize       = 512;
        initialLearnRate    = 0.005;
        learnRateDropPeriod = 2;
        maxEpochs           = 3;
        options = trainingOptions("adam", ...
            "MiniBatchSize",miniBatchSize, ...
            "InitialLearnRate",initialLearnRate, ...
            "LearnRateDropPeriod",learnRateDropPeriod, ...
            "LearnRateSchedule","piecewise", ...
            "MaxEpochs",maxEpochs, ...
            "Shuffle","every-epoch", ...
            "Verbose",false);

        % 8. Train the network.
        net = trainNetwork(sequencesTrain,labelsTrain,layers,options);

        % 9. Evaluate the network. Call classify to get the predicted label
        % of each sequence. Take the mode of the per-sequence predictions
        % belonging to each file to get the file-level prediction.
        predictedLabelsPerSequence = classify(net,sequencesValidation);
        trueLabels = categorical(adsValidation.Labels);
        predictedLabels = trueLabels;
        idx1 = 1;
        for ii = 1:numel(trueLabels)
            predictedLabels(ii,:) = mode(predictedLabelsPerSequence(idx1:idx1 + sequencePerFileValidation{ii} - 1,:),1);
            idx1 = idx1 + sequencePerFileValidation{ii};
        end
        trueLabelsCrossFold{i} = trueLabels; %#ok<AGROW>
        predictedLabelsCrossFold{i} = predictedLabels; %#ok<AGROW>
    end
end
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [logbook,bestFeatures] = HelperSFS(ads,adsAug,extractor,direction)
    % logbook = HelperSFS(ads,adsAug,extractor,direction)
    % returns a table, logbook, that contains the feature configurations tested
    % and associated validation accuracy.
    %   ads       - audioDatastore object that points to the original data set (used for validation).
    %   adsAug    - audioDatastore object that points to the augmented data set (used for training).
    %   extractor - audioFeatureExtractor object. Set all features to test to true.
    %   direction - specify as 'forward' or 'backward'
    %
    %[logbook,bestFeatures] = HelperSFS(ads,adsAug,extractor,direction)
    % also returns a struct, bestFeatures, containing the best feature
    % configuration for audioFeatureExtractor.

    % Copyright 2019 The MathWorks, Inc.

    featuresToTest = fieldnames(info(extractor));
    N = numel(featuresToTest);

    % ---------------------------------------------------------------------
    % Set the initial feature configuration: all on for backward selection
    % or all off for forward selection.
    featureConfig  = info(extractor);
    for i = 1:N
        if strcmpi(direction,"backward")
            featureConfig.(featuresToTest{i}) = true;
        else
            featureConfig.(featuresToTest{i}) = false;
        end
    end
    % ---------------------------------------------------------------------

    % Initialize logbook to track feature configuration and accuracy.
    logbook = table(featureConfig,0,'VariableNames',["Feature Configuration","Accuracy"]);

    %% Perform sequential feature evaluation
    wrapperIdx = 1;
    bestAccuracy = 0;
    while wrapperIdx <= N
        % -----------------------------------------------------------------
        % Create a cell array containing all feature configurations to test
        % in the current loop.
        featureConfigsToTest = cell(numel(featuresToTest),1);
        for ii = 1:numel(featuresToTest)
            if strcmpi(direction,"backward")
                featureConfig.(featuresToTest{ii}) = false;
            else
                featureConfig.(featuresToTest{ii}) = true;
            end
            featureConfigsToTest{ii} = featureConfig;
            if strcmpi(direction,"backward")
                featureConfig.(featuresToTest{ii}) = true;
            else
                featureConfig.(featuresToTest{ii}) = false;
            end
        end
        % -----------------------------------------------------------------

        % Loop over every feature set.
        for ii = 1:numel(featureConfigsToTest)

            % -------------------------------------------------------------
            % Determine the current feature configuration to test. Update
            % the feature extractor.
            currentConfig = featureConfigsToTest{ii};
            set(extractor,currentConfig)
            % -------------------------------------------------------------

            % -------------------------------------------------------------
            % Train and get k-fold cross-validation accuracy for current
            % feature configuration.
            [trueLabels,predictedLabels] = HelperTrainAndValidateNetwork(ads,adsAug,extractor);
            trueLabelsMat = cat(1,trueLabels{:});
            predictedLabelsMat = cat(1,predictedLabels{:});
            valAccuracy = mean(trueLabelsMat==predictedLabelsMat)*100;
            % -------------------------------------------------------------

            % Update Logbook ----------------------------------------------
            result = table(currentConfig,valAccuracy, ...
                'VariableNames',["Feature Configuration","Accuracy"]);
            logbook = [logbook;result];
            % -------------------------------------------------------------

        end

        % -----------------------------------------------------------------
        % Determine and print the setting with the best accuracy. If
        % accuracy did not improve, end the run.
        [a,b] = max(logbook{:,'Accuracy'});
        if a <= bestAccuracy
            wrapperIdx = inf;
        else
            wrapperIdx = wrapperIdx + 1;
        end
        bestAccuracy = a;
        % -----------------------------------------------------------------

        % -----------------------------------------------------------------
        % Update the features-to-test based on the most recent winner.
        winner = logbook{b,'Feature Configuration'};
        fn = fieldnames(winner);
        tf = structfun(@(x)(x),winner);
        if strcmpi(direction,"backward")
            featuresToRemove = fn(~tf);
        else
            featuresToRemove = fn(tf);
        end
        for ii = 1:numel(featuresToRemove)
            loc =  strcmp(featuresToTest,featuresToRemove{ii});
            featuresToTest(loc) = [];
            if strcmpi(direction,"backward")
                featureConfig.(featuresToRemove{ii}) = false;
            else
                featureConfig.(featuresToRemove{ii}) = true;
            end
        end
        % -----------------------------------------------------------------

    end
    
    % ---------------------------------------------------------------------
    % Sort the logbook and make it more readable
    logbook(1,:) = []; % Delete placeholder first row
    logbook = sortrows(logbook,{'Accuracy'},{'descend'});
    bestFeatures = logbook{1,'Feature Configuration'};
    m = logbook{:,'Feature Configuration'};
    fn = fieldnames(m);
    myString = strings(numel(m),1);
    for wrapperIdx = 1:numel(m)
        tf = structfun(@(x)(x),logbook{wrapperIdx,'Feature Configuration'});
        myString(wrapperIdx) = strjoin(fn(tf),", ");
    end
    logbook = table(myString,logbook{:,'Accuracy'},'VariableNames',["Features","Accuracy"]);
    % ---------------------------------------------------------------------
end
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%