Sequential Feature Selection for Audio Features

This example shows a typical workflow for feature selection applied to the task of spoken digit recognition.

In sequential feature selection, you train a network on a given feature set and then incrementally add or remove features until the highest accuracy is reached [1]. In this example, you apply sequential forward selection to the task of spoken digit recognition using the Free Spoken Digit Dataset [2].

Streaming Spoken Digit Recognition

To motivate the example, begin by loading a pretrained network and the audioFeatureExtractor (Audio Toolbox) object used to train the network.

load("network_Audio_SequentialFeatureSelection.mat","bestNet","afe");

Create an audioDeviceReader (Audio Toolbox) to read audio from a microphone. Create three dsp.AsyncBuffer (DSP System Toolbox) objects: one to buffer audio read from your microphone, one to buffer short-term energy of the input audio for speech detection, and one to buffer predictions.

fs = afe.SampleRate;

deviceReader = audioDeviceReader(SampleRate=fs,SamplesPerFrame=256);

audioBuffer = dsp.AsyncBuffer(fs*3);
steBuffer = dsp.AsyncBuffer(1000);
predictionBuffer = dsp.AsyncBuffer(5);

Create a plot to display the streaming audio, the probability the network outputs during inference, and the prediction.

fig = figure;

streamAxes = subplot(3,1,1);
streamPlot = plot(zeros(fs,1));
ylabel("Amplitude")
xlabel("Time (s)")
title("Audio Stream")
streamAxes.XTick = [0,fs];
streamAxes.XTickLabel = [0,1];
streamAxes.YLim = [-1,1];

analyzedAxes = subplot(3,1,2);
analyzedPlot = plot(zeros(fs/2,1));
title("Analyzed Segment")
ylabel("Amplitude")
xlabel("Time (s)")
set(gca,XTickLabel=[])
analyzedAxes.XTick = [0,fs/2];
analyzedAxes.XTickLabel = [0,0.5];
analyzedAxes.YLim = [-1,1];

probabilityAxes = subplot(3,1,3);
probabilityPlot = bar(0:9,0.1*ones(1,10));
axis([-1,10,0,1])
ylabel("Probability")
xlabel("Class")

Perform streaming digit recognition (digits 0 through 9) for 20 seconds. While the loop runs, speak one of the digits and check whether the network recognizes it.

First, define a short-term energy threshold under which to assume a signal contains no speech.

steThreshold = 0.015;
idxVec = 1:fs;
tic
while toc < 20
    
    % Read in a frame of audio from your device.
    audioIn = deviceReader();
    
    % Write the audio into the buffer.
    write(audioBuffer,audioIn);
    
    % While more than 200 ms of data is unread, continue this loop.
    while audioBuffer.NumUnreadSamples > 0.2*fs
        
        % Read 1 second from the audio buffer. Of that 1 second, 800 ms
        % is rereading old data and 200 ms is new data.
        audioToAnalyze = read(audioBuffer,fs,0.8*fs);
        
        % Update the figure to plot the current audio data.
        streamPlot.YData = audioToAnalyze;

        ste = mean(abs(audioToAnalyze));
        write(steBuffer,ste);
        if steBuffer.NumUnreadSamples > 5
            abc = sort(peek(steBuffer));
            steThreshold = abc(round(0.4*numel(abc)));
        end
        if ste > steThreshold
            
            % Use the detectSpeech function to determine if a region of speech
            % is present.
            idx = detectSpeech(audioToAnalyze,fs);
            
            % If a region of speech is present, perform the following.
            if ~isempty(idx)
                % Keep only the first detected speech region, then pad or
                % trim it to 0.5 seconds.
                audioToAnalyze = audioToAnalyze(idx(1,1):idx(1,2));
                audioToAnalyze = resize(audioToAnalyze,fs/2,Side="both");
                
                % Normalize the audio.
                audioToAnalyze = audioToAnalyze/max(abs(audioToAnalyze));
                
                % Update the analyzed segment plot
                analyzedPlot.YData = audioToAnalyze;

                % Extract the features.
                features = extract(afe,audioToAnalyze);
                
                % Replace any NaN or Inf feature values with zero, then call
                % predict to compute the class probabilities.
                features(isnan(features)|isinf(features)) = 0;
                scores = predict(bestNet,features);

                % Update the plot with the probabilities and the winning
                % label.
                probabilityPlot.YData = scores;
                write(predictionBuffer,scores);

                if predictionBuffer.NumUnreadSamples == predictionBuffer.Capacity
                    recentScores = peek(predictionBuffer);
                    [~,decision] = max(mean(recentScores.*hann(size(recentScores,1)),1));
                    probabilityAxes.Title.String = num2str(decision-1);
                end
            end
        else
            % If the signal energy is below the threshold, assume no speech
            % detected.
             probabilityAxes.Title.String = "";
             probabilityPlot.YData = 0.1*ones(10,1);
             analyzedPlot.YData = zeros(fs/2,1);
             reset(predictionBuffer)
        end
        
        drawnow limitrate
    end
end
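
The decision step in the loop above averages the buffered score vectors with Hann weights before picking a digit. The following sketch isolates that step on made-up scores. The score values are hypothetical, chosen only to show that a spurious spike at the edge of the buffer does not change the decision, because a five-point Hann window gives the first and last entries zero weight.

```matlab
% Five hypothetical score vectors (rows) over the 10 digit classes.
scores = 0.05*ones(5,10);
scores(:,4) = [0.2;0.6;0.7;0.6;0.2];   % column 4 corresponds to digit 3
scores(1,9) = 0.9;                     % spurious early spike for digit 8

% Symmetric Hann weights for a 5-element buffer: [0; 0.5; 1; 0.5; 0].
w = hann(5);

% Weight each row, average over time, and pick the strongest class.
[~,decision] = max(mean(scores.*w,1));
digit = decision - 1;                  % class indices are 1-based, digits are 0-based
disp(digit)
```

With these values the weighted average favors digit 3, even though the largest single score in the buffer belongs to digit 8.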

The remainder of the example illustrates how the network used in the streaming detection was trained and how the features fed into the network were chosen.

Create Train and Validation Data Sets

Download the Free Spoken Digit Dataset (FSDD) [2]. FSDD consists of short audio files with spoken digits (0-9).

downloadFolder = matlab.internal.examples.downloadSupportFile("audio","FSDD.zip");
dataFolder = tempdir;
unzip(downloadFolder,dataFolder)
dataset = fullfile(dataFolder,"FSDD");

Create an audioDatastore (Audio Toolbox) to point to the recordings. Get the sample rate of the data set.

ads = audioDatastore(dataset,IncludeSubfolders=true);
[~,adsInfo] = read(ads);
fs = adsInfo.SampleRate;

The first character of each file name is the digit spoken in the file. Get the first character of each file name, convert the characters to categorical, and then set the Labels property of the audioDatastore.

[~,filenames] = cellfun(@(x)fileparts(x),ads.Files,UniformOutput=false);
ads.Labels = categorical(string(cellfun(@(x)x(1),filenames)));
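
To see what the two lines above do, the following sketch applies the same calls to a pair of file paths. The paths are hypothetical and only imitate the digit-first naming pattern described above. They are stored as character vectors in a cell array, which is how the Files property of the datastore holds them.

```matlab
% Hypothetical file paths, for illustration only.
files = {fullfile('FSDD','0_jackson_0.wav'); fullfile('FSDD','7_theo_12.wav')};

% Strip folder and extension, keeping only the base file name.
[~,names] = cellfun(@(x)fileparts(x),files,UniformOutput=false);

% The first character of each name is the spoken digit.
labels = categorical(string(cellfun(@(x)x(1),names)));
disp(labels)
```

Note that indexing with x(1) relies on each name being a character vector. For a string scalar, x(1) would return the whole string rather than its first character.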

To split the datastore into a training set and a validation set, use splitEachLabel (Audio Toolbox). Allocate 80% of the data for training and the remaining 20% for validation.

[adsTrain,adsValidation] = splitEachLabel(ads,0.8);

Set Up Audio Feature Extractor

Create an audioFeatureExtractor (Audio Toolbox) object to extract audio features over 30 ms windows with an update rate of 10 ms. Set all features you would like to test in this example to true.

win = hamming(round(0.03*fs),"periodic");
overlapLength = round(0.02*fs);

afe = audioFeatureExtractor( ...
    Window=win, ...
    OverlapLength=overlapLength, ...
    SampleRate=fs, ...
    ...
    linearSpectrum=false, ...
    melSpectrum=false, ...
    barkSpectrum=false, ...
    erbSpectrum=false, ...
    ...
    mfcc=true, ...
    mfccDelta=true, ...
    mfccDeltaDelta=true, ...
    gtcc=true, ...
    gtccDelta=true, ...
    gtccDeltaDelta=true, ...
    ...
    spectralCentroid=true, ...
    spectralCrest=true, ...
    spectralDecrease=true, ...
    spectralEntropy=true, ...
    spectralFlatness=true, ...
    spectralFlux=true, ...
    spectralKurtosis=true, ...
    spectralRolloffPoint=true, ...
    spectralSkewness=true, ...
    spectralSlope=true, ...
    spectralSpread=true, ...
    ...
    pitch=false, ...
    harmonicRatio=false, ...
    zerocrossrate=false, ...
    shortTimeEnergy=false);
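
As a quick check on the framing, the following sketch works out the hop length and the number of analysis frames in a 0.5-second segment. It assumes the 8 kHz sample rate that this example reports for the data set, and it assumes that a trailing partial frame is dropped rather than padded. The variable names are chosen for illustration.

```matlab
fs = 8000;                                  % assumed sample rate (Hz)
windowLength = round(0.03*fs);              % 30 ms window -> 240 samples
overlapLength = round(0.02*fs);             % 20 ms overlap -> 160 samples
hopLength = windowLength - overlapLength;   % 80 samples, a 10 ms update rate

segmentLength = fs/2;                       % 0.5 s segment, as used in this example
% Number of full windows that fit, assuming no padding of the last frame.
numFrames = floor((segmentLength - windowLength)/hopLength) + 1;

fprintf("hop = %d samples (%g ms), frames per segment = %d\n", ...
    hopLength,1000*hopLength/fs,numFrames)
```

Under these assumptions, each 0.5-second segment yields a sequence of 48 feature vectors, which is the sequence length the recurrent network sees.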

Define Layers and Training Options

Define the list of deep learning layers and the trainingOptions used in this example. The first layer, sequenceInputLayer, is just a placeholder. Depending on which features you test during sequential feature selection, the first layer is replaced with a sequenceInputLayer of the appropriate size.

numUnits = 100;
layers = [ ...
    sequenceInputLayer(1)
    bilstmLayer(numUnits,OutputMode="last")
    fullyConnectedLayer(numel(categories(adsTrain.Labels)))
    softmaxLayer];

options = trainingOptions("adam", ...
    LearnRateSchedule="piecewise", ...
    Shuffle="every-epoch", ...
    Verbose=false, ...
    MaxEpochs=20, ...
    ResetInputNormalization=false);

Sequential Feature Selection

In the basic form of sequential feature selection, you train a network on a given feature set and then incrementally add or remove features until the accuracy no longer improves [1].

Forward Selection

Consider a simple case of forward selection on a set of four features. In the first forward selection loop, each of the four features is tested independently by training a network on it and comparing the resulting validation accuracies. The feature that results in the highest validation accuracy is noted. In the second forward selection loop, the best feature from the first loop is combined with each of the remaining features, so that each pair of features is used for training. If the accuracy in the second loop does not improve over the accuracy in the first loop, the selection process ends. Otherwise, a new best feature set is selected. The forward selection loop continues until the accuracy no longer improves.
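
The four-feature case can be sketched with a toy scoring function standing in for "train a network and measure validation accuracy". Everything in this sketch is hypothetical: the feature names, the gain and penalty values, and the score function are invented for illustration, and the loop is a simplified greedy variant rather than the helper function listed at the end of this example.

```matlab
featureNames = ["A","B","C","D"];
gain = [60 25 10 0];        % hypothetical accuracy contribution per feature
penalty = 2;                % hypothetical cost for each selected feature
score = @(mask) sum(gain(mask)) - penalty*nnz(mask);

selected = false(1,numel(featureNames));
bestScore = -inf;
while ~all(selected)
    candidates = find(~selected);
    candidateScores = zeros(size(candidates));
    for k = 1:numel(candidates)
        trial = selected;
        trial(candidates(k)) = true;        % add one remaining feature
        candidateScores(k) = score(trial);
    end
    [roundBest,idx] = max(candidateScores);
    if roundBest <= bestScore
        break                               % no improvement, so stop
    end
    bestScore = roundBest;
    selected(candidates(idx)) = true;       % keep the winning feature
end
disp(featureNames(selected))
```

With these made-up numbers, the loop adds A, then B, then C, and stops when adding D lowers the score, so D is never selected.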

Backward Selection

In backward feature selection, you begin by training on a feature set that consists of all features and test whether or not accuracy improves as you remove features.

Run Sequential Feature Selection

The helper functions (sequentialFeatureSelection, trainAndValidateNetwork) implement forward or backward sequential feature selection. Specify the training datastore, validation datastore, audio feature extractor, network layers, network options, and direction. As a general rule, choose forward if you anticipate a small feature set or backward if you anticipate a large feature set.

direction = 'forward';
[logbook,bestFeatures,bestNet] = sequentialFeatureSelection(adsTrain,adsValidation,afe,layers,options,direction);

The logbook output from sequentialFeatureSelection is a table containing all feature configurations tested and the corresponding validation accuracy.

logbook
logbook=62×2 table
                           Features                            Accuracy
    _______________________________________________________    ________

    "mfcc, mfccDeltaDelta, gtccDelta"                              98  
    "mfcc, mfccDeltaDelta, gtcc, gtccDelta"                     97.25  
    "mfcc, gtccDelta, spectralFlux"                                97  
    "mfcc, gtccDelta, gtccDeltaDelta"                           96.75  
    "mfcc, mfccDeltaDelta, gtccDelta, gtccDeltaDelta"           96.75  
    "mfcc, mfccDeltaDelta, gtccDelta, spectralSlope"             96.5  
    "mfcc, gtccDelta"                                           96.25  
    "mfcc, mfccDelta, gtccDelta"                                96.25  
    "mfcc, mfccDeltaDelta, gtccDelta, spectralEntropy"          96.25  
    "mfccDelta, gtccDelta"                                         96  
    "gtccDelta, spectralRolloffPoint"                              96  
    "mfcc, gtccDelta, spectralRolloffPoint"                        96  
    "mfcc, mfccDeltaDelta, gtccDelta, spectralFlux"                96  
    "mfcc, mfccDelta, mfccDeltaDelta, gtccDelta"                95.75  
    "mfcc, mfccDeltaDelta, gtccDelta, spectralRolloffPoint"     95.75  
    "gtccDelta"                                                  95.5  
      ⋮

The bestFeatures output from sequentialFeatureSelection contains a structure with the optimal features set to true.

bestFeatures
bestFeatures = struct with fields:
                    mfcc: 1
               mfccDelta: 0
          mfccDeltaDelta: 1
                    gtcc: 0
               gtccDelta: 1
          gtccDeltaDelta: 0
        spectralCentroid: 0
           spectralCrest: 0
        spectralDecrease: 0
         spectralEntropy: 0
        spectralFlatness: 0
            spectralFlux: 0
        spectralKurtosis: 0
    spectralRolloffPoint: 0
        spectralSkewness: 0
           spectralSlope: 0
          spectralSpread: 0

You can set your audioFeatureExtractor using the structure.

set(afe,bestFeatures)
afe
afe = 
  audioFeatureExtractor with properties:

   Properties
                     Window: [240×1 double]
              OverlapLength: 160
                 SampleRate: 8000
                  FFTLength: []
    SpectralDescriptorInput: 'linearSpectrum'
        FeatureVectorLength: 39

   Enabled Features
     mfcc, mfccDeltaDelta, gtccDelta

   Disabled Features
     linearSpectrum, melSpectrum, barkSpectrum, erbSpectrum, mfccDelta, gtcc
     gtccDeltaDelta, spectralCentroid, spectralCrest, spectralDecrease, spectralEntropy, spectralFlatness
     spectralFlux, spectralKurtosis, spectralRolloffPoint, spectralSkewness, spectralSlope, spectralSpread
     pitch, harmonicRatio, zerocrossrate, shortTimeEnergy


   To extract a feature, set the corresponding property to true.
   For example, obj.mfcc = true, adds mfcc to the list of enabled features.

sequentialFeatureSelection also outputs the best-performing network. The input layer of that network stores the normalization factors (mean and standard deviation) that correspond to the chosen features. To save the network and configured audioFeatureExtractor, uncomment this line:

% save('network_Audio_SequentialFeatureSelection.mat','bestNet','afe')

Supporting Functions

Train and Validate Network

function [tLabels,predictedLabels,net] = trainAndValidateNetwork(adsTrain,adsValidation,afe,layers,options)
% Train and validate a network.
%
%   INPUTS:
%   adsTrain      - audioDatastore object that points to training set
%   adsValidation - audioDatastore object that points to validation set
%   afe           - audioFeatureExtractor object
%   layers        - Layers of LSTM or BiLSTM network
%   options       - trainingOptions object
%
%   OUTPUTS:
%   tLabels          - true labels of validation set
%   predictedLabels  - predicted labels of validation set
%   net              - trained network

% Copyright 2019-2023 The MathWorks, Inc.

fs = afe.SampleRate;

% Isolate the training and validation labels
labelsTrain = adsTrain.Labels;
tLabels = adsValidation.Labels;

% Extract features from the training set.
adsTrain = transform(adsTrain,@(x)resize(x,fs/2,Side="both"));
adsTrain = transform(adsTrain,@(x)x/max(abs(x),[],"all"));
adsTrain = transform(adsTrain,@(x){extract(afe,x)});
featuresTrain = readall(adsTrain,UseParallel=canUseParallelPool);

% Extract the features from the validation set.
adsValidation = transform(adsValidation,@(x)resize(x,fs/2,Side="both"));
adsValidation = transform(adsValidation,@(x)x/max(abs(x),[],"all"));
adsValidation = transform(adsValidation,@(x){extract(afe,x)});
featuresValidation = readall(adsValidation,UseParallel=canUseParallelPool);

% Use the training set to determine the mean and standard deviation of each
% feature. Normalize the training and validation sets.
allFeatures = cat(1,featuresTrain{:});
allFeatures(isinf(allFeatures)) = nan;
[S,M] = std(allFeatures,0,1,"omitnan");

% Update input layer for the number of features under test.
layers(1) = sequenceInputLayer(afe.FeatureVectorLength, ...
    Normalization="zscore",Mean=M',StandardDeviation=S');

% Train the network.
net = trainnet(featuresTrain,labelsTrain,layers,"crossentropy",options);

% Evaluate the network.
scores = minibatchpredict(net,featuresValidation,MiniBatchSize=numel(featuresValidation));
predictedLabels = scores2label(scores,unique(tLabels));

end
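
The normalization statistics in the function above are computed per feature while ignoring non-finite values. The following sketch shows the same idea on a small made-up matrix. The data values are hypothetical, and the sketch uses separate mean and std calls instead of the two-output form of std used in the function.

```matlab
% Hypothetical frames-by-features matrix with one Inf entry.
X = [1 10; 2 20; 3 Inf; 4 40];

% Treat Inf as missing so that it does not distort the statistics.
X(isinf(X)) = NaN;

% Per-feature (column-wise) mean and standard deviation, ignoring NaN.
M = mean(X,1,"omitnan");
S = std(X,0,1,"omitnan");

% Z-score normalization, the operation the input layer applies.
Z = (X - M)./S;
disp(Z)
```

After normalization, each column of Z has zero mean and unit standard deviation over its finite entries, and the missing entry stays NaN.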

Sequential Feature Selection

function [logbook,bestFeatures,bestNet] = sequentialFeatureSelection(adsTrain,adsValidate,afeThis,layers,options,direction)
% Perform forward or backward sequential feature selection.
%
%   INPUTS:
%   adsTrain    - audioDatastore object that points to training set
%   adsValidate - audioDatastore object that points to validation set
%   afeThis     - audioFeatureExtractor object. Set all features to test to true
%   layers      - Layers of LSTM or BiLSTM network
%   options     - trainingOptions object
%   direction   - SFS direction, specify as 'forward' or 'backward'
%
%   OUTPUTS:
%   logbook         - table containing feature configurations tested and corresponding validation accuracies
%   bestFeatures    - structure containing best feature configuration
%   bestNet         - Trained network with highest validation accuracy

% Copyright 2019-2023 The MathWorks, Inc.

afe = copy(afeThis);
featuresToTest = fieldnames(info(afe));
N = numel(featuresToTest);
bestValidationAccuracy = 0;

% Set the initial feature configuration: all on for backward selection
% or all off for forward selection.
featureConfig = info(afe);
for i = 1:N
    if strcmpi(direction,"backward")
        featureConfig.(featuresToTest{i}) = true;
    else
        featureConfig.(featuresToTest{i}) = false;
    end
end

% Initialize logbook to track feature configuration and accuracy.
logbook = table(featureConfig,0,VariableNames=["Feature Configuration","Accuracy"]);

% Perform sequential feature evaluation.
wrapperIdx = 1;
bestAccuracy = 0;
while wrapperIdx <= N
    % Create a cell array containing all feature configurations to test
    % in the current loop.
    featureConfigsToTest = cell(numel(featuresToTest),1);
    for ii = 1:numel(featuresToTest)
        if strcmpi(direction,"backward")
            featureConfig.(featuresToTest{ii}) = false;
        else
            featureConfig.(featuresToTest{ii}) = true;
        end
        featureConfigsToTest{ii} = featureConfig;
        if strcmpi(direction,"backward")
            featureConfig.(featuresToTest{ii}) = true;
        else
            featureConfig.(featuresToTest{ii}) = false;
        end
    end

    % Loop over every feature set.
    for ii = 1:numel(featureConfigsToTest)

        % Determine the current feature configuration to test. Update
        % the audioFeatureExtractor.
        currentConfig = featureConfigsToTest{ii};
        set(afe,currentConfig)

        % Train and get the hold-out validation accuracy for the current
        % feature configuration.
        [trueLabels,predictedLabels,net] = trainAndValidateNetwork(adsTrain,adsValidate,afe,layers,options);
        valAccuracy = mean(trueLabels==predictedLabels)*100;
        if valAccuracy > bestValidationAccuracy
            bestValidationAccuracy = valAccuracy;
            bestNet = net;
        end

        % Update Logbook
        result = table(currentConfig,valAccuracy,VariableNames=["Feature Configuration","Accuracy"]);
        logbook = [logbook;result]; %#ok<AGROW> 

    end

    % Determine the setting with the best accuracy. If accuracy
    % did not improve, end the run.
    [a,b] = max(logbook{:,"Accuracy"});
    if a <= bestAccuracy
        wrapperIdx = inf;
    else
        wrapperIdx = wrapperIdx + 1;
    end
    bestAccuracy = a;

    % Update the features-to-test based on the most recent winner.
    winner = logbook{b,"Feature Configuration"};
    fn = fieldnames(winner);
    tf = structfun(@(x)(x),winner);
    if strcmpi(direction,"backward")
        featuresToRemove = fn(~tf);
    else
        featuresToRemove = fn(tf);
    end
    for ii = 1:numel(featuresToRemove)
        loc =  strcmp(featuresToTest,featuresToRemove{ii});
        featuresToTest(loc) = [];
        if strcmpi(direction,"backward")
            featureConfig.(featuresToRemove{ii}) = false;
        else
            featureConfig.(featuresToRemove{ii}) = true;
        end
    end

end

% Sort the logbook and make it more readable.
logbook(1,:) = []; % Delete placeholder first row.
logbook = sortrows(logbook,"Accuracy","descend");
bestFeatures = logbook{1,"Feature Configuration"};
m = logbook{:,"Feature Configuration"};
fn = fieldnames(m);
myString = strings(numel(m),1);
for wrapperIdx = 1:numel(m)
    tf = structfun(@(x)(x),logbook{wrapperIdx,"Feature Configuration"});
    myString(wrapperIdx) = strjoin(fn(tf),", ");
end
logbook = table(myString,logbook{:,"Accuracy"},VariableNames=["Features","Accuracy"]);
end

References

[1] Jain, A., and D. Zongker. "Feature Selection: Evaluation, Application, and Small Sample Performance." IEEE Transactions on Pattern Analysis and Machine Intelligence. Vol. 19, Issue 2, 1997, pp. 153-158.

[2] Jakobovski. "Jakobovski/Free-Spoken-Digit-Dataset." GitHub, May 30, 2019. https://github.com/Jakobovski/free-spoken-digit-dataset.