
Precision-Recall curve for Multiclass classification

Laraib on 28 Jul 2023
Edited: Drew on 24 Aug 2023
I have been trying hard to find any documentation or example related to plotting a precision-recall curve for multiclass classification, but it seems there is no way to do that. How would I make a precision-recall curve for my model?
Following is the code I use to get the confusion matrix:
fpath = 'E:\Research Data\Four Classes';
testData = fullfile(fpath, 'Test');
% %
testDatastore = imageDatastore(testData, 'IncludeSubfolders', true, 'LabelSource', 'foldernames');
allclass = [];
for i = 1:length(testDatastore.Labels)
I = readimage(testDatastore, i);
class = classify(incep, I);
allclass = [allclass class];
end
predicted = allclass';
% figure
% plotconfusion(testDatastore.Labels, predicted)
3 Comments
Laraib on 28 Jul 2023
What you suggested was extremely helpful. I appreciate your guidance.
I need to take one class at a time as the positive class, with the others as the negative class.
Is there any document or example I can look into?
I have a balanced number of images in each class, and each class is important (referencing the article).
Any MATLAB resource would help me.
Image Analyst on 29 Jul 2023
help roc
ROC Receiver operating characteristic.

The receiver operating characteristic is a metric used to check the quality of classifiers. For each class of a classifier, threshold values across the interval [0,1] are applied to outputs. For each threshold, two values are calculated, the True Positive Ratio (the proportion of the targets that are greater than or equal to the threshold that actually have a target value of one), and the False Positive Ratio (the proportion of the targets that are greater than or equal to the threshold that actually have a target value of zero).

ROC does not support categorical targets. To compute ROC metrics for categorical targets, use ROCMETRICS.

For single class problems, [TPR,FPR,TH] = roc(T,Y) takes a 1xQ target matrix T, where each element is either 1 or 0 indicating class membership or non-menbership respectively, and 1xQ outputs Y of values in the range [0,1]. It returns three 1xQ vectors: the true-positive/positive ratios TPR, the false-positive/negative ratios FPR, and the thresholds associated with each of those values TH.

For multi-class problems [TPR,FPR,TH] = roc(T,Y) takes an SxQ target matrix T, where each column contains a single 1 value, with all other elements 0. The row index of each 1 indicates which of S categories that vector represents. It also takes an SxQ output matrix Y, with values in the range [0,1]. The row indices of the largest elements in each column of Y indicate the most likely class. In the multi-class case, all three values returned are 1xS cell arrays, so that TPR{i}, FPR{i} and TH{i} are the ratios and thresholds for the ith class.

roc(T,Y) can also take a boolean row vector T, and row vector Y, in which case two categories are represented by targets 1 and 0.

Here a network is trained to recognize iris flowers the ROC is calculated and plotted.

[x,t] = iris_dataset;
net = patternnet(10);
net = train(net,x,t);
y = net(x);
[tpr,fpr,th] = roc(t,y)
plotroc(t,y)

See also ROCMETRICS, PLOTROC, CONFUSION

Documentation for roc: doc roc


Answers (1)

Drew on 23 Aug 2023
Edited: Drew on 24 Aug 2023
Given a multiclass classification problem, you can create a Precision-Recall curve for each class by treating each class in turn as the positive class in a one-vs-all binary classification problem. The Precision-Recall curves can be built with the rocmetrics function or the perfcurve function.
The documentation for the plot method of the rocmetrics object, https://www.mathworks.com/help/stats/rocmetrics.plot.html, has another Precision-Recall curve example: openExample('stats/PlotOtherPerformanceCurveExample')
The following applies to ROC curves but not to Precision-Recall curves: for multiclass problems, the plot method of the rocmetrics object can also create ROC curves from averaged metrics using the "AverageROCType" name-value argument, and the average method of the rocmetrics object can be used to calculate these averaged metrics (https://www.mathworks.com/help/stats/rocmetrics.average.html). An example of an average ROC curve is here: openExample('stats/PlotAverageROCCurveExample'). The averaging options include micro, macro, and weighted-macro.
In order to build the rocmetrics object with the rocmetrics function, or to use the perfcurve function, you will need the scores from the classify function.
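For the code in the question, that step might look roughly like the sketch below. It reuses the variables incep and testDatastore from the question and is not runnable without them. It assumes that the images already match the network input size and that the last layer of incep is the classification output layer, whose Classes property gives the column order of the scores.

```matlab
% Sketch only: incep and testDatastore are the variables from the question.
% classify accepts an image datastore directly and can return the scores
% as a second output, so the per-image loop is not needed.
[predicted, scores] = classify(incep, testDatastore);

% Assumption: the last layer is the classification output layer, and its
% Classes property lists the classes in the same order as the columns of scores.
classNames = incep.Layers(end).Classes;

% Build the rocmetrics object and request precision as an additional metric
rocObj = rocmetrics(testDatastore.Labels, scores, classNames, ...
    AdditionalMetrics="PositivePredictiveValue");

% Plot one-vs-all precision-recall curves for all classes
plot(rocObj, XAxisMetric="TruePositiveRate", ...
    YAxisMetric="PositivePredictiveValue");
xlabel("Recall"); ylabel("Precision");
```

If the column order of the scores is in doubt, comparing the class with the largest score in each row against predicted is a quick sanity check.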
Here is a precision-recall curve example for a tree model built with fisheriris data.
t=readtable("fisheriris.csv");
response="Species";
% Create set of models for 5-fold cross-validation
cvmdl=fitctree(t,response,KFold=5);
% Get cross-validation predictions and scores
[yfit,scores]=kfoldPredict(cvmdl);
% View confusion matrix
% The per-class Precision can be seen in the blue boxes in the
% column-summary along the bottom.
% The per-class Recall can be seen in the blue boxes in the row summary
% along the right side.
cm=confusionchart(t{:,response},yfit);
cm.ColumnSummary='column-normalized';
cm.RowSummary='row-normalized';
% Calculate precision, recall, and F1 per-class from the raw confusion
% matrix counts
% Precision = TP/(TP+FP); Recall = TP/(TP+FN);
% F1score is the harmonic mean of Precision and Recall.
% The cm.Normalization needs to be set to 'absolute', so that the values
% are raw counts.
counts = cm.NormalizedValues;
precisionPerClass= diag(counts)./ (sum(counts))';
recallPerClass = diag(counts)./ (sum(counts,2));
F1PerClass=2.*diag(counts) ./ ((sum(counts,2)) + (sum(counts))');
% Create rocmetrics object
% Add the metric "PositivePredictiveValue", which is Precision.
% The metric "TruePositiveRate", which is Recall, is in the Metrics by default.
rocObj=rocmetrics(t{:,response}, scores, cvmdl.ClassNames, ...
AdditionalMetrics="PositivePredictiveValue");
% For illustration, focus on metrics for one class, virginica
classindex=3;
% Plot the precision-recall curve for one class, given by classindex
% (By default, the rocmetrics plot function will plot the one-vs-all PR curves
% for all of the classes at once.)
r=plot(rocObj, YAxisMetric="PositivePredictiveValue", ...
XAxisMetric="TruePositiveRate", ...
ClassNames=rocObj.ClassNames(classindex));
hold on;
xlabel("Recall");ylabel("Precision"); title('Precision-Recall Curve');
% Place the operating point on the figure
scatter(recallPerClass(classindex),precisionPerClass(classindex),[],r.Color,"filled");
% Update legend
legend(strcat(string(rocObj.ClassNames(classindex))," one-vs-all P-R curve"), ...
strcat(string(rocObj.ClassNames(classindex)), ...
sprintf(' one-vs-all operating point\nP=%4.1f%%, R=%4.1f%%, F1=%4.1f%%', ...
100*precisionPerClass(classindex),100*recallPerClass(classindex), ...
100*F1PerClass(classindex))));
hold off;
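
As a rough alternative to the rocmetrics example above, the same kind of one-vs-all curve can be sketched with perfcurve. This version assumes the fisheriris data shipped with Statistics and Machine Learning Toolbox, and it uses the raw score column of the positive class. rocmetrics adjusts the scores for one-vs-all problems, so the two curves may not match exactly.

```matlab
% Sketch: one-vs-all precision-recall curve with perfcurve (fisheriris data)
load fisheriris                      % provides meas (150x4) and species (150x1 cell)
rng(0);                              % fixed seed so the folds are repeatable
cvmdl = fitctree(meas, species, KFold=5);
[~, scores] = kfoldPredict(cvmdl);   % one score column per class

posClass = 'virginica';              % class treated as the positive class
col = find(strcmp(cvmdl.ClassNames, posClass));

% XCrit="reca" requests recall on the x-axis, YCrit="prec" precision on the y-axis
[recall, precision, thresholds] = perfcurve(species, scores(:,col), posClass, ...
    XCrit="reca", YCrit="prec");

figure;
plot(recall, precision);
xlabel("Recall"); ylabel("Precision");
title([posClass ' one-vs-all P-R curve (perfcurve)']);
```

The first point of the curve can have NaN precision, because no observations are predicted positive at the highest threshold.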

Release: R2023a