How can I obtain the Shapley values from a Long Short-Term Memory network?

I have created an LSTM neural network and used it for regression. I want to calculate its Shapley values by executing this code:
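layers = [ ...
    sequenceInputLayer(numFeatures)      % sequence input with numFeatures channels
    lstmLayer(numHiddenUnits)            % OutputMode defaults to 'sequence'
    fullyConnectedLayer(numResponses)
    regressionLayer];
net = trainNetwork(Xtrain,Ytrain,layers,options);
blackbox = @(x)predict(net,x);           % the handle should predict on its input x, not on Xtrain
explainer = shapley(blackbox,Xtrain);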
I get the following error:
validateattributes(out,{'double','single'}, {'column','nonempty'},
mfilename,getString(message('stats:shapley:FunctionHandleOutput')));

Answers (1)

Ahmadreza on 27 Jan 2023
I think LSTM networks are not compatible with the shapley function. Currently, only the following models are supported:
Regression Model Object: Ensemble of regression models, Gaussian kernel regression model using random feature expansion, Gaussian process regression, Generalized additive model, Linear regression for high-dimensional data, Neural …
Classification Model Object: Discriminant analysis classifier, Multiclass model for support vector machines or other classifiers, Ensemble of learners for classification, Gaussian kernel classification model using random feature expansion, Generalized additive model, k-nearest neighbor classifier, Linear classification model, Multiclass naive Bayes model, Neural network classifier, Support vector machine classifier for one-class and binary classification, Binary decision tree for multiclass classification.
https://www.mathworks.com/help/stats/shapley.html#mw_c2327b12-104d-48ef-8a71-1f0e8769549b
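The list above covers model objects. The shapley documentation also accepts a function handle as blackbox, and for regression that handle must return a column vector with one prediction per row of the predictor data, which is the check failing in the validateattributes call from the question. A sequence network's predict returns sequences (a cell array, or a numResponses-by-T array for a single sequence) rather than one value per row, which is probably why the check fails. Below is a minimal sketch of an adapter, not an official API: the function name lstmBlackbox, its predictFcn argument, and the row layout [all features at t=1, all features at t=2, ...] are hypothetical choices, and it explains only the first response.

```matlab
function y = lstmBlackbox(predictFcn, X, numFeatures)
% LSTMBLACKBOX  Hypothetical adapter so that shapley can call a sequence network.
%   predictFcn  - handle mapping an N-by-1 cell array of sequences to
%                 predictions, e.g. @(s) predict(net, s)
%   X           - N-by-(numFeatures*seqLen) numeric matrix, one observation
%                 per row, laid out as [features at t=1, features at t=2, ...]
%                 (an assumed flattening, not something shapley defines)
%   y           - N-by-1 double column vector, the shape shapley expects
%                 from a regression function handle
    assert(mod(size(X, 2), numFeatures) == 0, ...
        'Number of columns of X must be a multiple of numFeatures.');
    N = size(X, 1);
    seqLen = size(X, 2) / numFeatures;

    % Turn each row back into a numFeatures-by-seqLen sequence.
    % reshape is column-major, so column t holds the features at time step t.
    seqs = cell(N, 1);
    for i = 1:N
        seqs{i} = reshape(X(i, :), numFeatures, seqLen);
    end

    Yraw = predictFcn(seqs);
    if iscell(Yraw)
        % Sequence-to-sequence output (lstmLayer OutputMode 'sequence', the
        % default): one numResponses-by-seqLen array per observation.
        % Keep the prediction at the last time step, as a row vector.
        Y = cell2mat(cellfun(@(c) c(:, end).', Yraw, 'UniformOutput', false));
    else
        % Sequence-to-one output (OutputMode 'last'): already N-by-numResponses
        Y = Yraw;
    end

    % shapley explains a single response; use the first one (assumption)
    y = double(Y(:, 1));
end
```

With a trained net and a hypothetical matrix Xflat built with the same row layout, the call might look like `blackbox = @(x) lstmBlackbox(@(s) predict(net, s), x, numFeatures);` followed by `explainer = shapley(blackbox, Xflat, 'QueryPoint', Xflat(1,:));`. Each (feature, time step) pair then becomes a separate predictor, so the Shapley values are per time step. If the network was trained on long sequences, scoring short independent windows can change what it predicts, so this is only a rough approximation.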
