
How to plot confusion matrix?

Adrian Kleffler on 22 May 2023
Edited: Venkat Siddarth on 29 May 2023
Hello guys, I want to plot a confusion matrix after training an object detector. Here is my code. How do I plot the confusion matrix?
data = load("letisko_labels_new.mat");
LabelData = data.gTruth.LabelData;
LabelData.imageFilename = fullfile(LabelData.imageFilename);
rng("default");
shuffledIndices = randperm(height(LabelData));
idx = floor(0.6 * length(shuffledIndices) );
trainingIdx = 1:idx;
trainingDataTbl = LabelData(shuffledIndices(trainingIdx),:);
validationIdx = idx+1 : idx + 1 + floor(0.1 * length(shuffledIndices) );
validationDataTbl = LabelData(shuffledIndices(validationIdx),:);
testIdx = validationIdx(end)+1 : length(shuffledIndices);
testDataTbl = LabelData(shuffledIndices(testIdx),:);
imdsTrain = imageDatastore(trainingDataTbl{:,"imageFilename"});
bldsTrain = boxLabelDatastore(trainingDataTbl(:,2:6));
imdsValidation = imageDatastore(validationDataTbl{:,"imageFilename"});
bldsValidation = boxLabelDatastore(validationDataTbl(:,2:6));
imdsTest = imageDatastore(testDataTbl{:,"imageFilename"});
bldsTest = boxLabelDatastore(testDataTbl(:,2:6));
trainingData = combine(imdsTrain,bldsTrain);
validationData = combine(imdsValidation,bldsValidation);
testData = combine(imdsTest,bldsTest);
validateInputData(trainingData);
validateInputData(validationData);
validateInputData(testData);
inputSize = [256 256 3];
className = ["kamera","lietadlo","satelit","stlp","veza"];
rng("default")
trainingDataForEstimation = transform(trainingData,@(data)preprocessData(data,inputSize));
numAnchors = 9;
[anchors,meanIoU] = estimateAnchorBoxes(trainingDataForEstimation,numAnchors);
area = anchors(:, 1).*anchors(:,2);
[~,idx] = sort(area,"descend");
anchors = anchors(idx,:);
anchorBoxes = {anchors(1:3,:)
    anchors(4:6,:)
    anchors(7:9,:)
    };
detector = yolov4ObjectDetector("csp-darknet53-coco",className,anchorBoxes,InputSize=inputSize);
augmentedTrainingData = transform(trainingData,@augmentData);
options = trainingOptions("adam",...
    GradientDecayFactor=0.9,...
    SquaredGradientDecayFactor=0.999,...
    InitialLearnRate=0.001,...
    LearnRateSchedule="none",...
    MiniBatchSize=4,...
    L2Regularization=0.0005,...
    MaxEpochs=50,...
    BatchNormalizationStatistics="moving",...
    DispatchInBackground=true,...
    ResetInputNormalization=false,...
    Shuffle="every-epoch",...
    VerboseFrequency=20,...
    ValidationFrequency=1000,...
    Plots="training-progress",...
    CheckpointPath='C:\BAKALARKA\checkpointYOLO',...
    ValidationData=validationData);
doTraining = true;
if doTraining
    % Train the YOLO v4 detector.
    [detector,info] = trainYOLOv4ObjectDetector(augmentedTrainingData,detector,options);
else
    % Load pretrained detector for the example.
    detector = downloadPretrainedYOLOv4Detector();
end
detectionResults = detect(detector,testData,'MiniBatchSize',4);
[ap,recall,precision] = evaluateDetectionPrecision(detectionResults,testData);
recallv = cell2mat(recall);
precisionv = cell2mat(precision);
[r,index] = sort(recallv);
p = precisionv(index);
figure
plot(r,p)
xlabel("Recall")
ylabel("Precision")
grid on
title(sprintf("Average Precision = %.2f",mean(ap)))

Answers (1)

Venkat Siddarth on 29 May 2023
Edited: Venkat Siddarth on 29 May 2023
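I understand that you want to plot a confusion matrix for the model. Here I am assuming that you want the confusion matrix for the Labels column in detectionResults, which you can compute with the confusionmat function. This function takes two vectors as input, the true labels and the predicted labels, and returns a confusion matrix whose rows correspond to the true classes and whose columns correspond to the predicted classes. Note that the two vectors must have the same length, so for an object detector each detected box first has to be paired with a ground-truth box (for example, by how much the boxes overlap) before the labels can be compared.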
y_true=[1 0 1 1 1 1 0 0];
y_pred=[1 1 1 1 0 0 1 1];
C=confusionmat(y_true,y_pred)
C = 2×2
     0     3
     2     3
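The same call works with categorical labels, which is the type object detectors use for their predicted labels, and the second output of confusionmat reports the class order used for the rows and columns. A minimal sketch with the toy vectors from above; the class names 'no' and 'yes' are placeholders chosen only for illustration:

```matlab
% Toy labels from the example above, converted to categorical.
% The class names 'no' and 'yes' are placeholders for illustration only.
y_true = categorical([1 0 1 1 1 1 0 0],[0 1],{'no','yes'});
y_pred = categorical([1 1 1 1 0 0 1 1],[0 1],{'no','yes'});

% C(i,j) counts observations whose true class is order(i)
% and whose predicted class is order(j).
[C,order] = confusionmat(y_true,y_pred);
disp(order')   % class order shared by the rows and the columns
disp(C)        % same counts as the numeric example: [0 3; 2 3]
```

Using named classes keeps the meaning of each row and column explicit, which matters once there are five classes as in the question.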
After generating the confusion matrix, you can plot it with the confusionchart function (confusionchart also accepts the true and predicted label vectors directly, e.g. confusionchart(y_true,y_pred)):
confusionchart(C)
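For the detector itself, confusionmat cannot be applied to detectionResults directly, because each image can have a different number of detected and ground-truth boxes. One way to get paired label vectors is to match the detections of each image to its ground-truth boxes by intersection over union (IoU) and to add a "background" class for missed objects and false positives. The function below is a minimal sketch of that idea, not a MATLAB or toolbox function: its name matchDetectionsForConfusion, the helper boxIoU, the "background" class name, and the greedy highest-score-first matching are all assumptions made for this example. Boxes are assumed to be axis-aligned [x y width height] rows.

```matlab
function [trueLabels,predLabels] = matchDetectionsForConfusion(gtBoxes,gtLabels,detBoxes,detScores,detLabels,iouThreshold)
% HYPOTHETICAL helper (not part of any toolbox): converts detection output
% into paired label vectors usable by confusionmat/confusionchart.
% Inputs are N-by-1 cell arrays, one cell per image:
%   gtBoxes/detBoxes   M-by-4 [x y width height] boxes
%   gtLabels/detLabels categorical (or string) labels, one per box
%   detScores          detection confidence, one per detected box
% Unmatched ground truth -> (true = class, predicted = "background")
% Unmatched detection    -> (true = "background", predicted = class)
trueLabels = strings(0,1);
predLabels = strings(0,1);
for i = 1:numel(gtBoxes)
    gtB = reshape(gtBoxes{i},[],4);   % guarantees 0-by-4 for empty images
    gtL = string(gtLabels{i}(:));
    dB  = reshape(detBoxes{i},[],4);
    dL  = string(detLabels{i}(:));
    % Greedy matching: most confident detections claim boxes first.
    [~,order] = sort(detScores{i}(:),"descend");
    used = false(size(gtB,1),1);
    for d = order'
        iou = boxIoU(dB(d,:),gtB);
        iou(used) = 0;                % each ground-truth box is matched once
        [best,g] = max(iou);
        if ~isempty(best) && best >= iouThreshold
            used(g) = true;
            trueLabels(end+1,1) = gtL(g);
            predLabels(end+1,1) = dL(d);
        else
            trueLabels(end+1,1) = "background";   % false positive
            predLabels(end+1,1) = dL(d);
        end
    end
    % Ground-truth objects that no detection matched (missed objects).
    trueLabels = [trueLabels; gtL(~used)];
    predLabels = [predLabels; repmat("background",nnz(~used),1)];
end
trueLabels = categorical(trueLabels);
predLabels = categorical(predLabels);
end

function iou = boxIoU(box,boxes)
% IoU of one [x y w h] box against every row of boxes (M-by-4).
x1 = max(box(1),boxes(:,1));
y1 = max(box(2),boxes(:,2));
x2 = min(box(1)+box(3),boxes(:,1)+boxes(:,3));
y2 = min(box(2)+box(4),boxes(:,2)+boxes(:,4));
inter = max(0,x2-x1).*max(0,y2-y1);
iou = inter./(box(3)*box(4) + boxes(:,3).*boxes(:,4) - inter);
end
```

With the variables from the question, and assuming that detectionResults stores Boxes, Scores and Labels as cell arrays with one entry per test image and that bldsTest.LabelData holds the ground-truth boxes in its first column and the labels in its second, a call could look like [t,p] = matchDetectionsForConfusion(bldsTest.LabelData(:,1),bldsTest.LabelData(:,2),detectionResults.Boxes,detectionResults.Scores,detectionResults.Labels,0.5); followed by confusionchart(t,p). The 0.5 IoU threshold is a common choice but still a parameter; every low-confidence detection also adds a false positive, so the detector's score threshold strongly affects the "background" row.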
To learn more about these functions, see the confusionmat and confusionchart documentation.
I hope this resolves the issue,
Regards
Venkat Siddarth V.
