Overfitting deep neural network
21 views (last 30 days)
I am using the CNN architecture ResNet-18 with transfer learning for classification. Overfitting happens after training and testing the model.
Here is my code. Can anyone please tell me what changes I have to make in the code below? Please see the attached result file, in which you can see that overfitting is happening.
clear all
close all
% Load the image dataset; folder names provide the class labels
imds = imageDatastore("D:\DatasetJPG", ...
    'IncludeSubfolders',true,'LabelSource','foldernames');
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.7); % 70% for training, 30% for validation
net = resnet18; % the first time, download the support package from the Add-On Explorer
% Replace the final layers for transfer learning
numClasses = numel(categories(imdsTrain.Labels));
lgraph = layerGraph(net);
newFCLayer = fullyConnectedLayer(numClasses,'Name','new_fc', ...
    'WeightLearnRateFactor',10,'BiasLearnRateFactor',10);
lgraph = replaceLayer(lgraph,'fc1000',newFCLayer);
newClassLayer = classificationLayer('Name','new_classoutput');
lgraph = replaceLayer(lgraph,'ClassificationLayer_predictions',newClassLayer);
% Train the network
inputSize = net.Layers(1).InputSize;
imageAugmenter = imageDataAugmenter( ...
    'RandRotation',[-5,5], ...
    'RandXTranslation',[-5 5], ...
    'RandYTranslation',[-5 5]);
augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain,'DataAugmentation',imageAugmenter);
augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation);
options = trainingOptions('sgdm', ...
    'MiniBatchSize',10, ...
    'MaxEpochs',20, ...
    'InitialLearnRate',1e-3, ...
    'Shuffle','every-epoch', ...
    'ValidationData',augimdsValidation, ...
    'ValidationFrequency',5, ...
    'Verbose',false, ...
    'Plots','training-progress');
trainedNet = trainNetwork(augimdsTrain,lgraph,options);
% Evaluate on the validation set
YPred = classify(trainedNet,augimdsValidation);
accuracy = mean(YPred == imdsValidation.Labels)
C = confusionmat(imdsValidation.Labels,YPred)
cm = confusionchart(imdsValidation.Labels,YPred);
cm.Title = 'Confusion Matrix for Validation Data';
cm.ColumnSummary = 'column-normalized';
cm.RowSummary = 'row-normalized';
0 Comments
Accepted Answer
Sugandhi on 28 Apr 2023
Edited: Sugandhi on 28 Apr 2023
Hi Muhammad,
I understand that you are using the ResNet-18 CNN architecture with transfer learning for classification, and that overfitting occurs after training and testing the model.
Based on the code you provided, here are some workarounds to address the overfitting in your ResNet-18 model:
- Increase the amount of data augmentation: Data augmentation artificially enlarges your dataset by applying random transformations to the images during training. The added variability makes the model more robust to overfitting. You can increase the amount of augmentation by adding more random transformations, such as horizontal flipping, vertical flipping, random scaling, or brightness/contrast changes (a sketch follows this list).
- Use dropout regularization: Dropout randomly sets a fraction of its input units to zero at each update during training, which keeps the model from relying too heavily on particular features and encourages more generalized representations. You can insert a dropoutLayer from the Deep Learning Toolbox between the global pooling layer and the new fully connected layer (see the dropout sketch below).
- Reduce the learning rate: A learning rate that is too high for fine-tuning can push the pretrained weights far from their good starting values and destabilize training. Try reducing 'InitialLearnRate' in your trainingOptions call, for example to a lower value such as 1e-4 or 1e-5 (see the training-options sketch below).
- Use early stopping: Early stopping monitors the validation loss during training and stops when it no longer improves, which limits overfitting. In trainingOptions this is controlled by the 'ValidationPatience' option; a value such as 5 or 10 validations is reasonable (see the training-options sketch below).
- Add more training data: Overfitting often occurs when the model is not exposed to enough diverse training data. Consider enlarging your training set by collecting more data, in addition to using augmentation to generate synthetic variations.
- Try using a smaller model: ResNet-18 is a relatively deep model with a large number of parameters, which makes it prone to overfitting when the training set is small. You can try a smaller pretrained network, such as SqueezeNet, or a custom architecture with fewer layers, to see whether that reduces overfitting.
- Regularize the fully connected layers: You can penalize large weights with L2 regularization. The 'L2Regularization' option of trainingOptions sets the global weight-decay factor, and the 'WeightL2Factor' and 'BiasL2Factor' properties of fullyConnectedLayer scale it per layer (see the regularization sketch below).
Implementing these changes can help reduce overfitting in your ResNet-18 model and improve its generalization performance. Minimal sketches of the main code changes follow; the specific values in them are starting points to tune rather than tuned settings.
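A minimal sketch of a stronger augmentation pipeline, restricted to transformations that imageDataAugmenter supports; brightness/contrast jitter is not among them and would need a custom transform of the datastore, which is not shown here. The ranges are illustrative assumptions to tune:

% Stronger augmentation: random flips, rotation, translation, and scaling.
imageAugmenter = imageDataAugmenter( ...
    'RandXReflection',true, ...      % horizontal flips
    'RandYReflection',true, ...      % vertical flips; use only if labels are flip-invariant
    'RandRotation',[-15 15], ...
    'RandXTranslation',[-10 10], ...
    'RandYTranslation',[-10 10], ...
    'RandScale',[0.9 1.1]);
augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain, ...
    'DataAugmentation',imageAugmenter);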
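A sketch of inserting dropout between the global average pooling layer and the new fully connected layer. It assumes the pooling layer in resnet18 is named 'pool5'; confirm the name in your network with analyzeNetwork(net) or by inspecting lgraph.Layers before running this:

% Insert a dropout layer between the pooling layer and the new FC layer.
lgraph = addLayers(lgraph,dropoutLayer(0.5,'Name','new_dropout')); % 0.5 is a common default
lgraph = disconnectLayers(lgraph,'pool5','new_fc');
lgraph = connectLayers(lgraph,'pool5','new_dropout');
lgraph = connectLayers(lgraph,'new_dropout','new_fc');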
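A sketch of adjusted training options combining a lower initial learning rate with early stopping through 'ValidationPatience'; the values shown are reasonable starting points, not tuned settings:

options = trainingOptions('sgdm', ...
    'MiniBatchSize',10, ...
    'MaxEpochs',20, ...
    'InitialLearnRate',1e-4, ...        % lower than the original 1e-3
    'Shuffle','every-epoch', ...
    'ValidationData',augimdsValidation, ...
    'ValidationFrequency',5, ...
    'ValidationPatience',5, ...         % stop after 5 validations with no improvement
    'Verbose',false, ...
    'Plots','training-progress');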
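For L2 regularization, trainingOptions provides a global 'L2Regularization' factor (default 1e-4), and the 'WeightL2Factor' and 'BiasL2Factor' properties of fullyConnectedLayer scale it per layer; the factor values below are assumptions to tune:

% Raise the per-layer L2 multiplier on the new fully connected layer;
% the effective penalty for the layer is WeightL2Factor * L2Regularization.
newFCLayer = fullyConnectedLayer(numClasses,'Name','new_fc', ...
    'WeightLearnRateFactor',10,'BiasLearnRateFactor',10, ...
    'WeightL2Factor',2,'BiasL2Factor',0);
lgraph = replaceLayer(lgraph,'fc1000',newFCLayer);
% Raise the global weight-decay factor in the training options, e.g.
% add 'L2Regularization',1e-3 to the trainingOptions call above.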
More Answers (0)