Classify ECG Signals Using Long Short-Term Memory Networks with GPU Acceleration
This example shows how to classify heartbeat electrocardiogram (ECG) data from the PhysioNet 2017 Challenge using deep learning and signal processing. In particular, the example uses Long Short-Term Memory networks and time-frequency analysis with GPU acceleration. You must have Parallel Computing Toolbox™ and a supported GPU. For details, see GPU Computing Requirements (Parallel Computing Toolbox).
This example reproduces Classify ECG Signals Using Long Short-Term Memory Networks, which computes the time-frequency features exclusively on the CPU, and accelerates those feature computations on the GPU.
Introduction
ECGs record the electrical activity of a person's heart over a period of time. Physicians use ECGs to visually detect whether a patient's heartbeat is normal or irregular.
Atrial fibrillation (AFib) is a type of irregular heartbeat that occurs when the heart's upper chambers, the atria, beat out of coordination with the lower chambers, the ventricles.
This example uses ECG data from the PhysioNet 2017 Challenge [1], [2], [3], which is available at https://physionet.org/challenge/2017/. The data consists of a set of ECG signals sampled at 300 Hz and divided by a group of experts into four different classes: Normal (N), AFib (A), Other Rhythm (O), and Noisy Recording (~). This example shows how to automate the classification process using deep learning. The procedure explores a binary classifier that can differentiate Normal ECG signals from signals showing signs of AFib.
This example uses long short-term memory (LSTM) networks, a type of recurrent neural network (RNN) well suited to studying sequence and time-series data. An LSTM network can learn long-term dependencies between time steps of a sequence. The LSTM layer (lstmLayer (Deep Learning Toolbox)) can look at the time sequence in the forward direction, while the bidirectional LSTM layer (bilstmLayer (Deep Learning Toolbox)) can look at the time sequence in both forward and backward directions. This example uses a bidirectional LSTM layer.
To accelerate feature extraction, training, and inference, this example uses a GPU and Parallel Computing Toolbox.
Load and Examine Data
Run the ReadPhysionetData script to download the data from the PhysioNet website and generate a MAT-file (PhysionetData.mat) that contains the ECG signals in the appropriate format. Downloading the data might take a few minutes. Use a conditional statement that runs the script only if PhysionetData.mat does not already exist in the current folder.
if ~isfile('PhysionetData.mat')
    ReadPhysionetData
end
load PhysionetData
The loading operation adds two variables to the workspace: Signals and Labels. Signals is a cell array that holds the ECG signals. Labels is a categorical array that holds the corresponding ground-truth labels of the signals.
Signals(1:5)'
ans=1×5 cell array
{1×9000 double} {1×9000 double} {1×18000 double} {1×9000 double} {1×18000 double}
Labels(1:5)
ans = 5×1 categorical
N
N
N
A
A
Use the summary function to see how many AFib signals and Normal signals are contained in the data.
summary(Labels)
     A       738
     N      5050
Generate a histogram of signal lengths. Most of the signals are 9000 samples long.
L = cellfun(@length,Signals);
h = histogram(L);
xticks(0:3000:18000);
xticklabels(0:3000:18000);
title('Signal Lengths')
xlabel('Length')
ylabel('Count')
Visualize a segment of one signal from each class. AFib heartbeats are spaced out at irregular intervals while Normal heartbeats occur regularly. AFib heartbeat signals also often lack a P wave, which pulses before the QRS complex in a Normal heartbeat signal. The plot of the Normal signal shows a P wave and a QRS complex.
normal = Signals{1};
aFib = Signals{4};
subplot(2,1,1)
plot(normal)
title('Normal Rhythm')
xlim([4000,5200])
ylabel('Amplitude (mV)')
text(4330,150,'P','HorizontalAlignment','center')
text(4370,850,'QRS','HorizontalAlignment','center')
subplot(2,1,2)
plot(aFib)
title('Atrial Fibrillation')
xlim([4000,5200])
xlabel('Samples')
ylabel('Amplitude (mV)')
Prepare Data for Training
During training, the trainnet function splits the data into mini-batches. The function then pads or truncates signals in the same mini-batch so they all have the same length. Too much padding or truncating can have a negative effect on the performance of the network, because the network might interpret a signal incorrectly based on the added or removed information.
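The following minimal sketch makes the padding step concrete using toy data. It assumes right-side zero padding up to the longest sequence in the mini-batch, which we believe matches the default trainingOptions sequence settings. The variable names are illustrative only. The sketch also shows how a padded mini-batch is laid out as a channel-by-time-by-batch (CTB) array.

```matlab
% Toy mini-batch: three single-channel sequences of different lengths.
seqs = {[1 2 3], [4 5], [6 7 8 9]};
C = size(seqs{1},1);                          % number of channels (1)
T = max(cellfun(@(s) size(s,2), seqs));       % longest sequence length (4)
B = numel(seqs);                              % mini-batch size (3)
% Assumed behavior: zero-pad on the right up to the longest sequence.
X = zeros(C,T,B);
for b = 1:B
    len = size(seqs{b},2);
    X(:,1:len,b) = seqs{b};
end
size(X)          % returns [1 4 3]
squeeze(X)'      % rows: [1 2 3 0], [4 5 0 0], [6 7 8 9]
% Truncating to the shortest sequence instead discards samples.
Tmin = min(cellfun(@(s) size(s,2), seqs));    % 2
Xtrunc = X(:,1:Tmin,:);                       % size [1 2 3]
```

If 9000-sample and 18000-sample signals land in the same mini-batch, half of each shorter padded sequence would consist of padding.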
To avoid excessive padding or truncating, apply the segmentSignals function to the ECG signals so they are all 9000 samples long. The function ignores signals with fewer than 9000 samples. If a signal has more than 9000 samples, segmentSignals breaks it into as many 9000-sample segments as possible and ignores the remaining samples. For example, a signal with 18500 samples becomes two 9000-sample signals, and the remaining 500 samples are ignored.
[Signals,Labels] = segmentSignals(Signals,Labels);
View the first five elements of the Signals array to verify that each entry is now 9000 samples long.
Signals(1:5)'
ans=1×5 cell array
{1×9000 double} {1×9000 double} {1×9000 double} {1×9000 double} {1×9000 double}
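The segmentation rule described above can be sketched in base MATLAB. This is a hypothetical reimplementation for illustration only, not the segmentSignals function used in the example. It runs on random toy signals rather than on the PhysioNet data.

```matlab
% Hypothetical sketch of the segmentation rule (not the segmentSignals helper):
% keep as many complete 9000-sample segments as possible, drop the rest.
segLen = 9000;
sigs = {randn(1,18500), randn(1,8000), randn(1,9000)};  % toy signals
labs = categorical(["A";"N";"N"]);                      % toy labels
outSigs = {};
srcIdx = [];                          % which input signal each segment came from
for k = 1:numel(sigs)
    x = sigs{k};
    nSeg = floor(numel(x)/segLen);    % 0 when the signal is too short
    for s = 1:nSeg
        outSigs{end+1,1} = x((s-1)*segLen + (1:segLen)); %#ok<AGROW>
        srcIdx(end+1,1) = k;                             %#ok<AGROW>
    end
end
outLabs = labs(srcIdx);               % each segment inherits its signal's label
```

The 18500-sample signal yields two segments. The 8000-sample signal is dropped, and the 9000-sample signal is kept as is, which gives three segments labeled A, A, and N.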
Train Classifier Using Raw Signal Data
To design the classifier, use the raw signals generated in the previous section. Split the signals into three sets: a training set to train the classifier, a validation set to monitor training, and a testing set to test the accuracy of the classifier on new data.
Use the summary function to show that the ratio of AFib signals to Normal signals is 718:4937, or approximately 1:7.
summary(Labels)
     A       718
     N      4937
Because about 7/8 of the signals are Normal, the classifier could learn that it can achieve a high accuracy simply by classifying all signals as Normal. To avoid this bias, augment the AFib data by duplicating AFib signals in the dataset so that there are roughly as many AFib signals as Normal signals. This duplication, commonly called oversampling, is one form of data augmentation used in deep learning.
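As a quick check of the numbers in this paragraph, the following worked arithmetic uses the class counts reported by summary(Labels). It computes the fraction of Normal signals and the repeat factor for the AFib signals.

```matlab
% Worked arithmetic using the class counts after segmentation.
numA = 718;                                % AFib segments
numN = 4937;                               % Normal segments
fracN = numN/(numA + numN)                 % about 0.873, close to 7/8
repeatFactor = round(numN/numA)            % 4937/718 is about 6.88, so 7
numAOversampled = repeatFactor*numA        % 5026 AFib copies vs 4937 Normal
```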
Split the signals according to their class.
afibX = Signals(Labels=='A');
afibY = Labels(Labels=='A');
normalX = Signals(Labels=='N');
normalY = Labels(Labels=='N');
Next, use dividerand to divide the targets from each class randomly into training, validation, and testing sets.
rng("default")
[trainIndA,validIndA,testIndA] = dividerand(length(afibX),0.8,0.1,0.1);
[trainIndN,validIndN,testIndN] = dividerand(length(normalX),0.8,0.1,0.1);
XTrainA = afibX(trainIndA);
YTrainA = afibY(trainIndA);
XTrainN = normalX(trainIndN);
YTrainN = normalY(trainIndN);
XValidA = afibX(validIndA);
YValidA = afibY(validIndA);
XValidN = normalX(validIndN);
YValidN = normalY(validIndN);
XTestA = afibX(testIndA);
YTestA = afibY(testIndA);
XTestN = normalX(testIndN);
YTestN = normalY(testIndN);
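A similar random 80/10/10 split can be sketched in base MATLAB by shuffling the indices with randperm and cutting them into blocks. This is an illustrative substitute for dividerand. It does not reproduce the exact indices that dividerand returns, and the rounding of the block sizes is an assumption.

```matlab
% Illustrative substitute for dividerand using only base MATLAB.
rng("default")
n = 718;                                   % e.g., number of AFib segments
idx = randperm(n);                         % random permutation of 1..n
nTrain = round(0.8*n);                     % 574
nValid = round(0.1*n);                     % 72
trainInd = idx(1:nTrain);
validInd = idx(nTrain+1:nTrain+nValid);
testInd  = idx(nTrain+nValid+1:end);       % the remainder (72) goes to testing
```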
The dataset is imbalanced. To achieve a similar number of AFib and Normal signals, repeat the AFib signals seven times.
The training options used later in this example set 'Shuffle' to 'every-epoch', so the training data is reshuffled before every pass. This ensures that contiguous signals in a mini-batch do not all have the same label.
XTrain = [repmat(XTrainA,7,1); XTrainN];
YTrain = [repmat(YTrainA,7,1); YTrainN];
XValid = [repmat(XValidA,7,1); XValidN];
YValid = [repmat(YValidA,7,1); YValidN];
XTest = [repmat(XTestA,7,1); XTestN];
YTest = [repmat(YTestA,7,1); YTestN];
The distribution between Normal and AFib signals is now roughly balanced in the training, validation, and testing sets.
summary(YTrain)
     A      4018
     N      3949
summary(YValid)
     A       504
     N       494
summary(YTest)
     A       504
     N       494
Define LSTM Network Architecture
LSTM networks can learn long-term dependencies between time steps of sequence data. This example uses the bidirectional LSTM layer bilstmLayer because it looks at the sequence in both forward and backward directions.
Because the input signals have one dimension each, specify the input size to be sequences of size 1. Specify a bidirectional LSTM layer with 50 hidden units and output the last element of the sequence. This command instructs the bidirectional LSTM layer to map the input time series into 100 features (50 from each direction) and then prepares the output for the fully connected layer. Finally, specify two classes by including a fully connected layer of size 2, followed by a softmax layer.
layers = [ ...
    sequenceInputLayer(1)
    bilstmLayer(50,'OutputMode','last')
    fullyConnectedLayer(2)
    softmaxLayer
    ]
layers = 
  4×1 Layer array with layers:
     1   ''   Sequence Input    Sequence input with 1 dimensions
     2   ''   BiLSTM            BiLSTM with 50 hidden units
     3   ''   Fully Connected   2 fully connected layer
     4   ''   Softmax           softmax
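The following worked arithmetic counts the learnable parameters in this architecture. It assumes the standard LSTM cell, which has four gates, each with input weights, recurrent weights, and a bias. It also assumes a BiLSTM output that concatenates the forward and backward hidden states. The same formula gives the count for the two-feature network used later in the example.

```matlab
% Assumption: standard LSTM cell with 4 gates per direction.
numHidden = 50;                             % bilstmLayer(50,...)
numClasses = 2;                             % fullyConnectedLayer(2)
lstmParams = @(inputSize) 4*numHidden*(inputSize + numHidden + 1);
bilstmOut = 2*numHidden;                    % forward + backward outputs = 100
fcParams = numClasses*bilstmOut + numClasses;   % weights + biases = 202
totalRaw = 2*lstmParams(1) + fcParams       % 1 input channel:  21002
totalTF  = 2*lstmParams(2) + fcParams       % 2 input channels: 21402
```

Because the BiLSTM layer has only about 21,000 learnable parameters, most of the computational cost comes from the sequence length rather than from the size of the network.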
Next, specify the training options for the classifier. Set 'MaxEpochs' to 150 to allow the network to make 150 passes through the training data. A 'MiniBatchSize' of 200 directs the network to look at 200 training signals at a time. Set 'InitialLearnRate' to 0.001. Set 'GradientThreshold' to 1 to limit the size of the gradients, which helps stabilize training. Set 'Shuffle' to 'every-epoch' to reshuffle the training data before each pass. Pass the validation set with 'ValidationData' to track the validation accuracy during training. Specify 'Plots' as 'training-progress' to generate plots that show the training progress as the number of iterations increases, and set 'Metrics' to 'accuracy' to include the classification accuracy in those plots. Set 'Verbose' to false to suppress the table output that corresponds to the data shown in the plot. If you want to see this table, set 'Verbose' to true. Set 'OutputNetwork' to 'last-iteration' to return the network from the final training iteration. Because each training sequence has rows corresponding to channels and columns corresponding to time steps, specify the input data format 'CTB' (channel, time, batch).
This example uses the adaptive moment estimation (ADAM) solver. ADAM often performs better than the stochastic gradient descent with momentum (SGDM) solver on RNNs such as LSTMs.
options = trainingOptions('adam', ...
    'MaxEpochs',150, ...
    'MiniBatchSize',200, ...
    'GradientThreshold',1, ...
    'Shuffle','every-epoch', ...
    'InitialLearnRate',1e-3, ...
    'ExecutionEnvironment','gpu', ...
    'Plots','training-progress', ...
    'Metrics','accuracy', ...
    'InputDataFormats','CTB', ...
    'ValidationData',{XValid,YValid}, ...
    'Verbose',false, ...
    'OutputNetwork','last-iteration');
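The 'GradientThreshold',1 option limits the size of each gradient update. The following minimal sketch shows L2-norm clipping on a toy gradient vector and assumes the default L2-norm threshold method. As we understand the documentation, trainnet applies this clipping internally to the gradient of each learnable parameter, so you do not need to write it yourself.

```matlab
% Sketch of L2-norm gradient clipping with threshold 1 (toy gradient).
threshold = 1;
g = [3; 4];                            % toy gradient with L2 norm 5
gNorm = norm(g);
if gNorm > threshold
    gClipped = g*(threshold/gNorm);    % rescale to norm 1, keep the direction
else
    gClipped = g;                      % small gradients pass through unchanged
end
gClipped                               % [0.6; 0.8]
```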
Train LSTM Network
Train the LSTM network with the specified training options and layer architecture by using trainnet (Deep Learning Toolbox). Because the training set is large, the training process can take several minutes.
net = trainnet(XTrain,YTrain,layers,'crossentropy',options);
The top subplot of the training-progress plot represents the training accuracy, which is the classification accuracy on each mini-batch. When training progresses successfully, this value typically increases towards 100%. The bottom subplot displays the training loss, which is the cross-entropy loss on each mini-batch. When training progresses successfully, this value typically decreases towards zero.
If the training is not converging, the plots might oscillate between values without trending in a certain upward or downward direction. This oscillation means that the training accuracy is not improving and the training loss is not decreasing. This situation can occur from the start of training, or the plots might plateau after some preliminary improvement in training accuracy. In many cases, changing the training options can help the network achieve convergence. Decreasing MiniBatchSize or decreasing InitialLearnRate might result in a longer training time, but it can help the network learn better.
Here, the training accuracy is very high, but the validation accuracy has not improved correspondingly. This gap likely indicates overfitting: the model fits the training data too closely and does not generalize to new data. There could be many reasons for this. For example, the raw training data might contain a lot of redundant and irrelevant information, so the network does not learn the factors that are key to the classification.
Visualize Training and Testing Accuracy
Calculate the training accuracy, which represents the accuracy of the classifier on the signals on which it was trained. First, classify the training data.
To make predictions with multiple observations, use the minibatchpredict (Deep Learning Toolbox) function. To convert the prediction scores to labels, use the scores2label function. The minibatchpredict function automatically uses a GPU if one is available. Using a GPU requires a Parallel Computing Toolbox™ license and a supported GPU device. For information on supported devices, see GPU Computing Requirements (Parallel Computing Toolbox). Otherwise, the function uses the CPU.
classNames = categories(YTrain);
scores = minibatchpredict(net,XTrain, ...
    'InputDataFormats','CTB', ...
    'ExecutionEnvironment','gpu');
trainPred = scores2label(scores,classNames);
In classification problems, confusion matrices are used to visualize the performance of a classifier on a set of data for which the true values are known. In the chart, True Class is the ground-truth label of the signal, and Predicted Class is the label that the network assigns to the signal. The axis labels represent the class labels, AFib (A) and Normal (N).
Compute the overall classification accuracy of the training data predictions as the percentage of correctly classified signals. Then use the confusionchart command to visualize the predictions. Specify 'RowSummary' as 'row-normalized' to display the true positive rates and false negative rates in the row summary. Also, specify 'ColumnSummary' as 'column-normalized' to display the positive predictive values and false discovery rates in the column summary.
LSTMAccuracy = sum(trainPred == YTrain)/numel(YTrain)*100
LSTMAccuracy = 99.0210
figure
confusionchart(YTrain,trainPred,'ColumnSummary','column-normalized', ...
    'RowSummary','row-normalized','Title','Confusion Chart for LSTM');
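To see what the confusion chart and its summaries contain, the following sketch builds a confusion matrix from made-up toy labels using base MATLAB. It then computes the overall accuracy and the row-normalized and column-normalized summaries. Rows correspond to true classes and columns to predicted classes.

```matlab
% Toy labels for illustration only.
classes = categorical(["A";"N"]);
yTrue = categorical(["A";"A";"A";"N";"N";"N";"N"]);
yPred = categorical(["A";"A";"N";"N";"N";"N";"A"]);
cm = zeros(2);
for i = 1:2
    for j = 1:2
        cm(i,j) = sum(yTrue == classes(i) & yPred == classes(j));
    end
end
cm                                          % [2 1; 1 3]
accuracy = sum(diag(cm))/sum(cm,"all")*100  % same as sum(yPred==yTrue)/numel(yTrue)*100
rowSummary = cm./sum(cm,2)   % diagonal: true positive rate; off-diagonal: false negative rate
colSummary = cm./sum(cm,1)   % diagonal: positive predictive value; off-diagonal: false discovery rate
```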
Now classify the testing data with the same network.
scores = minibatchpredict(net,XTest, ...
    'InputDataFormats','CTB', ...
    'ExecutionEnvironment','gpu');
testPred = scores2label(scores,classNames);
Calculate the testing accuracy and visualize the classification performance as a confusion matrix.
LSTMAccuracy = sum(testPred == YTest)/numel(YTest)*100
LSTMAccuracy = 53.5070
figure
confusionchart(YTest,testPred,'ColumnSummary','column-normalized', ...
    'RowSummary','row-normalized','Title','Confusion Chart for LSTM');
Improve Performance with Feature Extraction
Feature extraction from the data can help improve the performance of the classifier. To decide which features to extract, this example adapts an approach that computes time-frequency images, such as spectrograms, and uses them to train convolutional neural networks (CNNs) [4], [5].
Visualize the spectrogram of each type of signal.
fs = 300;
figure
subplot(2,1,1);
pspectrum(normal,fs,'spectrogram','TimeResolution',0.5)
title('Normal Signal')
subplot(2,1,2);
pspectrum(aFib,fs,'spectrogram','TimeResolution',0.5)
title('AFib Signal')
Because this example uses an LSTM instead of a CNN, it is important to translate the approach so it applies to one-dimensional signals. Time-frequency (TF) moments extract information from the spectrograms. Each moment can be used as a one-dimensional feature to input to the LSTM.
Explore two TF moments in the time domain:
- Instantaneous frequency (instfreq)
- Spectral entropy (pentropy)
The instfreq function estimates the time-dependent frequency of a signal as the first moment of the power spectrogram. The function computes a spectrogram using short-time Fourier transforms over time windows. In this example, the function uses 255 time windows. The time outputs of the function correspond to the centers of the time windows.
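The idea of instantaneous frequency as a first spectral moment can be illustrated on a single frame. The following minimal sketch computes the power-weighted mean frequency of one unwindowed FFT frame of a 50 Hz tone. instfreq performs a computation of this kind for every time window of a spectrogram. Its windowing and scaling details may differ from this simplification.

```matlab
% Single-frame sketch of the first spectral moment (no window, assumption).
fs = 300;
t = (0:299)/fs;                 % 1 s of data gives 1 Hz frequency resolution
x = sin(2*pi*50*t);             % 50 Hz tone with an exact number of cycles
N = numel(x);
X = fft(x);
P = abs(X(1:N/2+1)).^2;         % one-sided power (unscaled)
f = (0:N/2)*fs/N;               % frequency of each bin in Hz
fMean = sum(f.*P)/sum(P)        % power-weighted mean frequency, about 50
```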
Visualize the instantaneous frequency for each type of signal.
[instFreqA,tA] = instfreq(aFib,fs);
[instFreqN,tN] = instfreq(normal,fs);
figure
subplot(2,1,1);
plot(tN,instFreqN)
title('Normal Signal')
xlabel('Time (s)')
ylabel('Instantaneous Frequency')
subplot(2,1,2);
plot(tA,instFreqA)
title('AFib Signal')
xlabel('Time (s)')
ylabel('Instantaneous Frequency')
Convert the training, testing, and validation sets to gpuArray objects to execute the instantaneous frequency computations on the GPU. Apply the instfreq function to every cell in each set.
gpuXTrain = cellfun(@gpuArray,XTrain,'UniformOutput',false);
instfreqTrain = cellfun(@(x)instfreq(x,fs),gpuXTrain,'UniformOutput',false);
gpuXTest = cellfun(@gpuArray,XTest,'UniformOutput',false);
instfreqTest = cellfun(@(x)instfreq(x,fs),gpuXTest,'UniformOutput',false);
gpuXValid = cellfun(@gpuArray,XValid,'UniformOutput',false);
instfreqValid = cellfun(@(x)instfreq(x,fs),gpuXValid,'UniformOutput',false);
The spectral entropy measures how spiky or flat the spectrum of a signal is. A signal with a spiky spectrum, like a sum of sinusoids, has low spectral entropy. A signal with a flat spectrum, like white noise, has high spectral entropy. The pentropy function estimates the spectral entropy based on a power spectrogram. As with the instantaneous frequency estimation case, pentropy uses 255 time windows to compute the spectrogram. The time outputs of the function correspond to the centers of the time windows.
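Similarly, the following sketch computes a normalized spectral entropy for a single FFT frame. It assumes that the power spectrum is first normalized to sum to 1 and that the entropy is divided by log2 of the number of frequency bins. The sketch contrasts a pure tone, which has a spiky spectrum, with an impulse, whose spectrum is flat. pentropy performs a windowed, frame-by-frame version of this kind of computation, so its values for real signals will not match this sketch exactly.

```matlab
% Single-frame sketch of normalized spectral entropy (assumed normalization).
N = 300;                                   % frame length
n = 0:N-1;
tone = sin(2*pi*50*n/300);                 % 50 cycles per frame: one dominant bin
impulse = [1 zeros(1,N-1)];                % flat magnitude spectrum
frames = {tone, impulse};
H = zeros(1,numel(frames));
for k = 1:numel(frames)
    X = fft(frames{k});
    P = abs(X(1:N/2+1)).^2;                % one-sided power (unscaled)
    p = P/sum(P);                          % normalize to a probability distribution
    p = p(p > 0);                          % use the convention 0*log2(0) = 0
    H(k) = -sum(p.*log2(p))/log2(N/2+1);   % divide by the maximum possible entropy
end
H   % H(1) is close to 0 (tone), H(2) is close to 1 (impulse)
```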
Visualize the spectral entropy for each type of signal.
[pentropyA,tA2] = pentropy(aFib,fs);
[pentropyN,tN2] = pentropy(normal,fs);
figure
subplot(2,1,1)
plot(tN2,pentropyN)
title('Normal Signal')
ylabel('Spectral Entropy')
subplot(2,1,2)
plot(tA2,pentropyA)
title('AFib Signal')
xlabel('Time (s)')
ylabel('Spectral Entropy')
Use cellfun to apply the pentropy function to every cell in the training, testing, and validation sets.
pentropyTrain = cellfun(@(x)pentropy(x,fs),gpuXTrain,'UniformOutput',false);
pentropyTest = cellfun(@(x)pentropy(x,fs),gpuXTest,'UniformOutput',false);
pentropyValid = cellfun(@(x)pentropy(x,fs),gpuXValid,'UniformOutput',false);
Concatenate the features so that each cell in the new training, testing, and validation sets has two dimensions, or two features.
XTrain2 = cellfun(@(x,y)[x y]',instfreqTrain,pentropyTrain,'UniformOutput',false);
XTest2 = cellfun(@(x,y)[x y]',instfreqTest,pentropyTest,'UniformOutput',false);
XValid2 = cellfun(@(x,y)[x y]',instfreqValid,pentropyValid,'UniformOutput',false);
Visualize the format of the new inputs. Each cell no longer contains one 9000-sample-long signal; now it contains two 255-sample-long features.
XTrain2(1:5)
ans=5×1 cell array
{2×255 gpuArray}
{2×255 gpuArray}
{2×255 gpuArray}
{2×255 gpuArray}
{2×255 gpuArray}
Standardize Data
The instantaneous frequency and the spectral entropy have means that differ by almost one order of magnitude. Furthermore, the instantaneous frequency mean might be too high for the LSTM to learn effectively. When a network is fit on data with a large mean and a large range of values, large inputs could slow down the learning and convergence of the network [6].
mean(instFreqN)
ans = 5.5551
mean(pentropyN)
ans = 0.6324
Use the training set mean and standard deviation to standardize the training, testing and validation sets. Standardization, or z-scoring, is a popular way to improve network performance during training.
XV = [XTrain2{:}];
mu = mean(XV,2);
sg = std(XV,[],2);
XTrainSD = cellfun(@(x)(x-mu)./sg,XTrain2,'UniformOutput',false);
XValidSD = cellfun(@(x)(x-mu)./sg,XValid2,'UniformOutput',false);
XTestSD = cellfun(@(x)(x-mu)./sg,XTest2,'UniformOutput',false);
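Show the means of the standardized instantaneous frequency and spectral entropy for the first signal in the training set. Because the standardization uses statistics computed over the entire training set, the means for an individual signal are close to zero but not exactly zero.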
instFreqNSD = XTrainSD{1}(1,:);
pentropyNSD = XTrainSD{1}(2,:);
mean(instFreqNSD)
ans = 0.1544
mean(pentropyNSD)
ans = 0.1935
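The standardization step can be checked on toy data. In the following sketch, each feature row is z-scored with the mean and standard deviation taken over all time steps of all training sequences, as in the example code. As a result, the pooled training features have a mean of 0 and a standard deviation of 1, while individual sequences generally do not. The toy values only loosely mimic the two features.

```matlab
rng("default")
% Two toy training sequences, each 2-by-T (feature-by-time).
trainCells = {[5 + randn(1,4); 0.6 + 0.1*randn(1,4)], ...
              [6 + randn(1,3); 0.7 + 0.1*randn(1,3)]};
XV = [trainCells{:}];                  % pool all time steps: 2-by-7
mu = mean(XV,2);                       % per-feature mean
sg = std(XV,[],2);                     % per-feature standard deviation
trainSD = cellfun(@(x)(x-mu)./sg,trainCells,'UniformOutput',false);
pooledSD = [trainSD{:}];
mean(pooledSD,2)                       % approximately [0; 0]
std(pooledSD,[],2)                     % approximately [1; 1]
mean(trainSD{1},2)                     % a single sequence is generally not exactly zero-mean
```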
Modify LSTM Network Architecture
Now that the signals each have two dimensions, you must modify the network architecture by specifying the input sequence size as 2. Specify a bidirectional LSTM layer with 50 hidden units, and output the last element of the sequence. Specify two classes by including a fully connected layer of size 2, followed by a softmax layer.
layers = [ ...
    sequenceInputLayer(2)
    bilstmLayer(50,'OutputMode','last')
    fullyConnectedLayer(2)
    softmaxLayer
    ]
layers = 
  4×1 Layer array with layers:
     1   ''   Sequence Input    Sequence input with 2 dimensions
     2   ''   BiLSTM            BiLSTM with 50 hidden units
     3   ''   Fully Connected   2 fully connected layer
     4   ''   Softmax           softmax
Specify the training options. Set the maximum number of epochs to 150 to allow the network to make 150 passes through the training data, and use the standardized validation set as the validation data.
options = trainingOptions('adam', ...
    'MaxEpochs',150, ...
    'MiniBatchSize',200, ...
    'GradientThreshold',1, ...
    'Shuffle','every-epoch', ...
    'InitialLearnRate',1e-3, ...
    'ExecutionEnvironment','gpu', ...
    'Plots','training-progress', ...
    'Metrics','accuracy', ...
    'InputDataFormats','CTB', ...
    'ValidationData',{XValidSD,YValid}, ...
    'OutputNetwork','last-iteration', ...
    'Verbose',false);
Train LSTM Network with Time-Frequency Features
Train the LSTM network with the specified training options and layer architecture by using trainnet (Deep Learning Toolbox).
net2 = trainnet(XTrainSD,YTrain,layers,"crossentropy",options);
Training takes less time than before because each TF-moment sequence has 255 time steps instead of 9000.
Visualize Training and Testing Accuracy
Classify the training data using the updated LSTM network. Visualize the classification performance as a confusion matrix.
scores = minibatchpredict(net2,XTrainSD, ...
    InputDataFormats="CTB", ...
    ExecutionEnvironment="gpu");
trainPred2 = scores2label(scores,classNames);
LSTMAccuracy = sum(trainPred2 == YTrain)/numel(YTrain)*100
LSTMAccuracy = 97.9666
figure
confusionchart(YTrain,trainPred2,'ColumnSummary','column-normalized', ...
    'RowSummary','row-normalized','Title','Confusion Chart for LSTM');
Classify the testing data with the updated network. Plot the confusion matrix to examine the testing accuracy.
scores = minibatchpredict(net2,XTestSD, ...
    InputDataFormats="CTB", ...
    ExecutionEnvironment="gpu");
testPred2 = scores2label(scores,classNames);
LSTMAccuracy = sum(testPred2 == YTest)/numel(YTest)*100
LSTMAccuracy = 95.7916
figure
confusionchart(YTest,testPred2,'ColumnSummary','column-normalized', ...
    'RowSummary','row-normalized','Title','Confusion Chart for LSTM');
Conclusion
This example shows how to build a classifier to detect atrial fibrillation in ECG signals using an LSTM network. The procedure uses oversampling to avoid the classification bias that occurs when one tries to detect abnormal conditions in populations composed mainly of healthy patients. Training the LSTM network using raw signal data results in poor classification accuracy on the testing data. Training the network using two time-frequency-moment features for each signal significantly improves the classification performance and also decreases the training time. In the run shown here, the testing accuracy rises from about 54% to about 96%.
References
[1] AF Classification from a Short Single Lead ECG Recording: the PhysioNet/Computing in Cardiology Challenge, 2017. https://physionet.org/challenge/2017/
[2] Clifford, Gari, Chengyu Liu, Benjamin Moody, Li-wei H. Lehman, Ikaro Silva, Qiao Li, Alistair Johnson, and Roger G. Mark. "AF Classification from a Short Single Lead ECG Recording: The PhysioNet Computing in Cardiology Challenge 2017." Computing in Cardiology (Rennes: IEEE). Vol. 44, 2017, pp. 1–4.
[3] Goldberger, A. L., L. A. N. Amaral, L. Glass, J. M. Hausdorff, P. Ch. Ivanov, R. G. Mark, J. E. Mietus, G. B. Moody, C.-K. Peng, and H. E. Stanley. "PhysioBank, PhysioToolkit, and PhysioNet: Components of a New Research Resource for Complex Physiologic Signals." Circulation. Vol. 101, No. 23, 13 June 2000, pp. e215–e220. http://circ.ahajournals.org/content/101/23/e215.full
[4] Pons, Jordi, Thomas Lidy, and Xavier Serra. "Experimenting with Musically Motivated Convolutional Neural Networks." 14th International Workshop on Content-Based Multimedia Indexing (CBMI). June 2016.
[5] Wang, D. "Deep learning reinvents the hearing aid." IEEE Spectrum. Vol. 54, No. 3, March 2017, pp. 32–37. doi: 10.1109/MSPEC.2017.7864754.
[6] Brownlee, Jason. "How to Scale Data for Long Short-Term Memory Networks in Python." 7 July 2017. https://machinelearningmastery.com/how-to-scale-data-for-long-short-term-memory-networks-in-python/.
See Also
Functions
instfreq | pentropy | trainingOptions (Deep Learning Toolbox) | trainnet (Deep Learning Toolbox) | bilstmLayer (Deep Learning Toolbox) | lstmLayer (Deep Learning Toolbox)
Objects
gpuArray (Parallel Computing Toolbox)
Related Topics
- Long Short-Term Memory Neural Networks (Deep Learning Toolbox)