Model Selection using Hold-out validation

Tan on 30 Apr 2024
Answered: Shubham on 12 May 2024
Project overview: I have a total of 1200 images in three classes and need to split them into training, validation and test sets in a 70:20:10 ratio, then perform model selection and evaluation.
Is the code below correct?
clc
clear all;
close all;
% Dataset location
outputFolder = fullfile('D:\recycle101');
rootFolder = fullfile(outputFolder, 'project');
% Class folder names (called classNames so it does not shadow the built-in categories function)
classNames = {'aluminiumcan', 'petbottle', 'drinkcartonbox'};
% Load dataset
imds = imageDatastore(fullfile(rootFolder, classNames), 'LabelSource', 'foldernames');
% Determine the number of images
numImages = numel(imds.Files);
% Define ratios for splitting the dataset
trainRatio = 0.7;
valRatio = 0.2;
testRatio = 0.1;
% Calculate the number of images for each set
numTrain = round(trainRatio * numImages);
numVal = round(valRatio * numImages);
numTest = numImages - numTrain - numVal;
% Shuffle the dataset
imds = shuffle(imds);
% Split the dataset into training, validation, and test sets
imdsTrain = subset(imds, 1:numTrain);
imdsValidation = subset(imds, numTrain+1:numTrain+numVal);
imdsTest = subset(imds, numTrain+numVal+1:numTrain+numVal+numTest);
% Display dataset counts
disp('Training Set:');
tblTrain = countEachLabel(imdsTrain)
disp('Validation Set:');
tblValidation = countEachLabel(imdsValidation)
disp('Test Set:');
tblTest = countEachLabel(imdsTest)
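% Randomly choose one training image from each class, so each title matches its image
idxCan = find(imdsTrain.Labels == 'aluminiumcan');
idxPET = find(imdsTrain.Labels == 'petbottle');
idxCarton = find(imdsTrain.Labels == 'drinkcartonbox');
AluminiumCan = idxCan(randi(numel(idxCan)));
PETBottles = idxPET(randi(numel(idxPET)));
DrinkCartonBox = idxCarton(randi(numel(idxCarton)));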
% Plot randomly chosen images
figure
subplot(2,2,1);
imshow(readimage(imdsTrain, AluminiumCan));
title('Aluminium Can');
subplot(2,2,2);
imshow(readimage(imdsTrain, PETBottles));
title('PET Bottle');
subplot(2,2,3);
imshow(readimage(imdsTrain, DrinkCartonBox));
title('Drink Carton Box');
% Load pre-trained network
net = googlenet;
analyzeNetwork(net)
lys = net.Layers;
lys(end-3:end)
% Number of classes present in the training labels
numClasses = numel(unique(imdsTrain.Labels));
lgraph = layerGraph(net);
% Replace the classification layers for the new task (output size = number of classes)
newFCLayer = fullyConnectedLayer(numClasses, 'Name', 'new_fc', 'WeightLearnRateFactor', 10, 'BiasLearnRateFactor', 10);
lgraphNew = replaceLayer(lgraph, 'loss3-classifier', newFCLayer);
newClassLayer = classificationLayer('Name', 'new_classoutput');
lgraphNew = replaceLayer(lgraphNew, 'output', newClassLayer);
analyzeNetwork(lgraphNew)
% Resize the images and train the network
imageSize = net.Layers(1).InputSize;
augmentedTrainingSet = augmentedImageDatastore(imageSize, imdsTrain, 'ColorPreprocessing', 'gray2rgb');
augmentedValidateSet = augmentedImageDatastore(imageSize, imdsValidation, 'ColorPreprocessing', 'gray2rgb');
options = trainingOptions('sgdm', ...
'MiniBatchSize', 4, ...
'MaxEpochs', 8, ...
'InitialLearnRate', 1e-4, ...
'Shuffle', 'every-epoch', ...
'ValidationData', augmentedValidateSet, ...
'ValidationFrequency', 3, ...
'Verbose', false, ...
'ExecutionEnvironment', 'cpu', ...
'Plots', 'training-progress');
trainedNet = trainNetwork(augmentedTrainingSet, lgraphNew, options);
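% Classify test set (resized to the network input size, like the training and validation sets)
augmentedTestSet = augmentedImageDatastore(imageSize, imdsTest, 'ColorPreprocessing', 'gray2rgb');
YPred = classify(trainedNet, augmentedTestSet);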
YTest = imdsTest.Labels;
% Calculate performance metrics
accuracy = sum(YPred == YTest) / numel(YTest);
confMat = confusionmat(YTest, YPred);
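% confusionmat: rows = true class, columns = predicted class
precision = diag(confMat)' ./ sum(confMat, 1);    % TP ./ (TP + FP), 1-by-3
recall = diag(confMat)' ./ sum(confMat, 2)';      % TP ./ (TP + FN), transposed so the result stays 1-by-3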
f1 = 2 * (precision .* recall) ./ (precision + recall);
overallPrecision = mean(precision);
overallRecall = mean(recall);
overallF1 = mean(f1);
% Display performance metrics
disp(['Accuracy: ', num2str(accuracy)]);
disp('Confusion Matrix:');
disp(confMat);
disp('Precision:');
disp(precision);
disp(['Overall Precision: ', num2str(overallPrecision)]);
disp('Recall:');
disp(recall);
disp(['Overall Recall: ', num2str(overallRecall)]);
disp(['F1 Score: ', num2str(overallF1)]);
% Plot confusion matrix directly from the labels so the axis labels follow the actual class order
cm = confusionchart(YTest, YPred);
cm.RowSummary = 'row-normalized';
cm.ColumnSummary = 'column-normalized';
% Save network
save(fullfile(outputFolder, 'simpleDL.mat'), 'trainedNet', 'lgraphNew');
% Testing Process
I = imread('drinkcartonbox2.JPEG');
ds = augmentedImageDatastore(imageSize, I, 'ColorPreprocessing', 'gray2rgb');
predictedLabel = classify(trainedNet, ds);
disp(['The loaded image belongs to ', char(predictedLabel), ' class']);
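To sanity-check the metric formulas, the sketch below computes them on a small, made-up confusion matrix (the values are hypothetical, not results from this project). `confusionmat` puts true classes in rows and predicted classes in columns, so precision divides the diagonal by the column sums and recall divides it by the row sums; transposing the row sums keeps recall a 1-by-3 vector instead of letting implicit expansion turn it into a 3-by-3 matrix.

```matlab
% Hypothetical 3-class confusion matrix, laid out like confusionmat output:
% rows = true class, columns = predicted class. The counts are made up.
confMat = [38  2  0;
            4 34  2;
            1  3 36];

tp = diag(confMat)';                  % true positives per class, 1-by-3
precision = tp ./ sum(confMat, 1);    % TP ./ (TP + FP): divide by column sums
recall    = tp ./ sum(confMat, 2)';   % TP ./ (TP + FN): divide by row sums, transposed to 1-by-3
f1 = 2 * (precision .* recall) ./ (precision + recall);
accuracy = sum(tp) / sum(confMat(:));

disp(['Accuracy: ', num2str(accuracy)]);
disp(['Macro precision: ', num2str(mean(precision))]);
disp(['Macro recall: ', num2str(mean(recall))]);
disp(['Macro F1: ', num2str(mean(f1))]);
```

Here every row sums to 40, so recall comes out as [0.95 0.85 0.9], while the column sums differ (43, 39, 38), which is what makes precision and recall differ per class.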

Answers (1)

Shubham on 12 May 2024
Hey Tan,
The provided code seems correct; however, I have a few suggestions. I am assuming you do not need to do any data preprocessing.
Although hold-out validation has its advantages (for example, it is simple and quick), for a small dataset you should consider K-fold cross-validation: every image is used for validation exactly once, and averaging over the K folds gives a more reliable performance estimate that depends less on one particular random split.
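A minimal K-fold sketch, assuming the final 10% test set has already been held out and that the remaining 1080 images are balanced (360 per class, a hypothetical figure). The `labels` vector stands in for `imds.Labels`; the per-fold training step is left as a comment because it would mean retraining GoogLeNet K times.

```matlab
% K-fold cross-validation skeleton on a stand-in label vector (hypothetical 360 images per class)
rng(0);                                   % reproducible partition
labels = categorical(repelem((1:3)', 360, 1), 1:3, ...
    {'aluminiumcan', 'petbottle', 'drinkcartonbox'});
K = 5;
% Passing labels as the grouping variable makes the folds stratified by class
c = cvpartition(labels, 'KFold', K);

timesInTest = zeros(numel(labels), 1);
for k = 1:c.NumTestSets
    trainIdx = find(training(c, k));      % indices for this fold's training subset
    valIdx = find(test(c, k));            % indices for this fold's validation subset
    timesInTest(valIdx) = timesInTest(valIdx) + 1;
    % Here you would train a fresh copy of lgraphNew on subset(imds, trainIdx)
    % and evaluate it on subset(imds, valIdx). Omitted in this sketch.
end

disp(['Validation fold sizes: ', mat2str(c.TestSize)]);
disp(['Each image validated exactly once: ', mat2str(all(timesInTest == 1))]);
```

Averaging the K validation scores (and looking at their spread) is what makes the estimate more robust than a single hold-out split; the untouched test set is still used only once, at the end.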
Ensure that the classes are evenly represented in each subset after splitting. The randomness in "shuffle" does not guarantee an even distribution of classes in each subset. Although you are displaying the label counts for each subset, I would suggest using a sampling technique that distributes each class evenly, i.e. stratified partitioning. You could use the "cvpartition" function with the "Stratify" option to create the partitions. Because a "Holdout" partition produces only two subsets, a 70:20:10 split needs two successive partitions.
c = cvpartition(group,"Holdout",p,"Stratify",stratifyOption)
You can refer to the documentation here: https://www.mathworks.com/help/stats/cvpartition.html
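One way to get a stratified 70:20:10 split with `cvpartition` is sketched below, assuming a balanced dataset (a hypothetical 400 images per class) and using a stand-in `labels` vector in place of `imds.Labels`. The first partition holds out 10% for testing; the second takes 2/9 of the remaining 90%, which is 20% of the total, for validation.

```matlab
% Stratified 70:20:10 split of a stand-in label vector (hypothetical 400 images per class)
rng(0);
labels = categorical(repelem((1:3)', 400, 1), 1:3, ...
    {'aluminiumcan', 'petbottle', 'drinkcartonbox'});
n = numel(labels);

% Step 1: hold out 10% as the final test set, stratified by class
c1 = cvpartition(labels, 'Holdout', 0.1, 'Stratify', true);
testIdx = find(test(c1));
restIdx = find(training(c1));

% Step 2: split the remaining 90% so that 20% of the total goes to validation (0.2/0.9 = 2/9)
c2 = cvpartition(labels(restIdx), 'Holdout', 2/9, 'Stratify', true);
valIdx = restIdx(test(c2));
trainIdx = restIdx(training(c2));

% Per-class counts of each subset
disp(table(countcats(labels(trainIdx)), countcats(labels(valIdx)), countcats(labels(testIdx)), ...
    'VariableNames', {'Train', 'Validation', 'Test'}, 'RowNames', categories(labels)));
```

With an imageDatastore, these indices would go to `subset`, for example `imdsTrain = subset(imds, trainIdx);`. For datastores specifically, `splitEachLabel(imds, 0.7, 0.2, 'randomized')` is a simpler built-in alternative that returns three per-class splits directly.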
The model preparation and evaluation are done correctly. You could try data augmentation techniques to further improve your model. Please refer to the following:
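As a rough illustration of on-the-fly augmentation, the sketch below builds an `imageDataAugmenter` and applies it through `augmentedImageDatastore`. The transformation ranges are illustrative guesses, not tuned values, and the random images and labels are stand-ins so the snippet runs on its own.

```matlab
% Random geometric transformations applied each time a mini-batch is read
augmenter = imageDataAugmenter( ...
    'RandXReflection', true, ...          % random horizontal flips
    'RandRotation', [-15 15], ...         % rotation angle in degrees
    'RandXTranslation', [-10 10], ...     % horizontal shift in pixels
    'RandYTranslation', [-10 10], ...     % vertical shift in pixels
    'RandScale', [0.9 1.1]);              % isotropic scaling factor

% Stand-in data: 8 random RGB images with labels
X = rand(256, 256, 3, 8, 'single');
Y = categorical(repmat({'aluminiumcan'; 'petbottle'}, 4, 1));
imageSize = [224 224 3];                  % GoogLeNet input size

augimds = augmentedImageDatastore(imageSize, X, Y, 'DataAugmentation', augmenter);
batch = read(augimds);                    % table with variables 'input' and 'response'
disp(size(batch.input{1}));
```

In the question's script, the same `'DataAugmentation', augmenter` pair would be added (next to `'ColorPreprocessing', 'gray2rgb'`) only when building the training datastore; the validation and test datastores should stay unaugmented so they measure performance on unmodified images.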
You may also speed up model training by using multiple GPUs as parallel workers. You may find this documentation useful: https://www.mathworks.com/help/deeplearning/ug/deep-learning-with-matlab-on-multiple-gpus.html
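A sketch of multi-GPU training options, assuming Parallel Computing Toolbox and several supported GPUs are available; `numGPUs` is a hypothetical value set by hand. Scaling the mini-batch size keeps the per-GPU batch constant, and scaling the learning rate along with it is a common heuristic, not a rule.

```matlab
numGPUs = 2;                 % hypothetical number of GPUs on the machine
baseMiniBatch = 4;           % mini-batch size used in the question's script
baseLearnRate = 1e-4;        % learning rate used in the question's script

options = trainingOptions('sgdm', ...
    'ExecutionEnvironment', 'multi-gpu', ...           % one worker per local GPU
    'MiniBatchSize', baseMiniBatch * numGPUs, ...      % keep the per-GPU batch size constant
    'InitialLearnRate', baseLearnRate * numGPUs, ...   % heuristic: scale with the total batch size
    'MaxEpochs', 8, ...
    'Shuffle', 'every-epoch', ...
    'Verbose', false);
```

These options would replace the `'ExecutionEnvironment', 'cpu'` setup in the question when calling `trainNetwork`; the hardware is only checked when training starts.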
I hope this helps!

Release: R2022b
