5-fold cross validation with neural networks (function approximation)
I have MATLAB code that implements hold-out cross-validation (shown below). I am looking for help performing 5-fold cross-validation with the same model architecture. Please help me figure this out. Thank you.
%% Data
X = x'; % input (always stays the same)
Y = yte'; % target
%% Model parameters
% Choose a training function ('trainlm', 'trainscg', 'traingdx')
trainFcn = 'trainlm';
% Choose the number of neurons in the hidden layer
hiddenLayerSize = 17;
% Create the network first: fitnet returns a new network object, so any
% settings assigned to net before this call would be discarded
net = fitnet(hiddenLayerSize,trainFcn);
%view(net) % view network
% Choose activation functions (logsig, tansig, purelin, poslin)
net.layers{1}.transferFcn = 'logsig'; % hidden layer
net.layers{2}.transferFcn = 'poslin'; % output layer
% Choose an evaluation metric (mae, mse)
net.performFcn = 'mse';
net.plotFcns = {'plotperform','plottrainstate','ploterrhist', 'plotregression', 'plotfit'};
%% Data-processing
net.input.processFcns = {'removeconstantrows','mapstd'}; % Input: remove constant rows, then normalize to zero mean and unit standard deviation
net.output.processFcns = {'removeconstantrows','mapstd'}; % Output: remove constant rows, then normalize to zero mean and unit standard deviation
%% Data split (0.7,0.15 & 0.15)
net.divideFcn = 'dividerand'; % split randomly
net.divideMode = 'sample'; % each observation is a sample
net.divideParam.trainRatio = 70/100; % training
net.divideParam.valRatio = 15/100; % validation
net.divideParam.testRatio = 15/100; % test
%% Train a neural network
[net,tr] = train(net,X,Y);
% net: trained network
% tr: training record
%% network performance
figure(1), plotperform(tr) % Plot network performance
figure(2), plottrainstate(tr) % Plot training state values.
%% Error and R2
Ytest = net(X); % prediction on X
e = gsubtract(Y,Ytest); % error = Yactual - Ypred
MSE = perform(net, Y,Ytest); % network performance using net.performFcn (here mse)
MAE=mae(net, Y,Ytest);
%% Regression performance
trOut = Ytest(tr.trainInd); % training output (predicted)
trTarg = Y(tr.trainInd); % training target (actual)
vOut = Ytest(tr.valInd); % val output
vTarg = Y(tr.valInd); % val target
tsOut = Ytest(tr.testInd); % test output
tsTarg = Y(tr.testInd); %test target
figure(4), plotregression(trTarg, trOut, 'Train', vTarg, vOut, 'Validation', tsTarg, tsOut, 'Testing',Y,Ytest,'All')
% R2
R2_Train= regression(trTarg, trOut)^2;
R2_Val= regression(vTarg, vOut)^2;
R2_Test= regression(tsTarg, tsOut)^2;
R2_all= regression(Y,Ytest)^2;
%figure(3), ploterrhist(e) % Plot error histogram
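One way to turn the hold-out script into 5-fold cross-validation is to assign every sample to one of five folds yourself and pass those indices to train through 'divideind', re-creating the network in each fold. This is a sketch under stated assumptions, not a verified answer: it needs the Deep Learning Toolbox, assumes a single-output target stored one sample per column (as Y = yte' suggests), reuses the architecture from the script above (17 logsig hidden neurons, poslin output, trainlm, mapstd), and, as its own choice, uses the fold after the test fold as the early-stopping validation set, so each model trains on about 60% of the data. The names foldId, valFold, MSE_fold and R2_fold, and the toy data used when X or Y is missing, are hypothetical.

```matlab
%% 5-fold cross-validation (sketch; requires Deep Learning Toolbox)
rng(0);                                   % reproducible toy data and fold assignment
if ~exist('X', 'var') || ~exist('Y', 'var')
    % Hypothetical toy data, used only if the script above has not been run
    X = linspace(0, 1, 200);              % inputs: one sample per column
    Y = X.^2 + 0.02*randn(size(X));       % single-output target
end

k = 5;                                    % number of folds
N = size(X, 2);                           % number of samples
foldId = mod(randperm(N), k) + 1;         % random fold label 1..k per sample; fold sizes differ by at most 1

MSE_fold = zeros(1, k);                   % test MSE of each fold
R2_fold  = zeros(1, k);                   % squared correlation on each test fold
for f = 1:k
    valFold  = mod(f, k) + 1;             % next fold (wrapping around) is used for early stopping
    testIdx  = find(foldId == f);         % fold f is held out for testing
    valIdx   = find(foldId == valFold);
    trainIdx = find(foldId ~= f & foldId ~= valFold);   % the other k-2 folds train the model

    % Re-create the network in every fold so no trained weights carry over
    net = fitnet(17, 'trainlm');          % same architecture as the hold-out script
    net.layers{1}.transferFcn = 'logsig';
    net.layers{2}.transferFcn = 'poslin';
    net.performFcn = 'mse';
    net.input.processFcns  = {'removeconstantrows','mapstd'};
    net.output.processFcns = {'removeconstantrows','mapstd'};
    net.trainParam.showWindow = false;    % do not open the training window in every fold

    % Use the fold indices instead of a random split
    net.divideFcn = 'divideind';
    net.divideParam.trainInd = trainIdx;
    net.divideParam.valInd   = valIdx;
    net.divideParam.testInd  = testIdx;

    net = train(net, X, Y);

    % Evaluate on the held-out fold only
    Ypred = net(X(:, testIdx));
    MSE_fold(f) = perform(net, Y(:, testIdx), Ypred);
    R2_fold(f)  = regression(Y(:, testIdx), Ypred)^2;
end

fprintf('5-fold CV: MSE = %.4g +/- %.4g, R2 = %.3f +/- %.3f\n', ...
    mean(MSE_fold), std(MSE_fold), mean(R2_fold), std(R2_fold));
```

Each sample lands in the test fold exactly once, so the five MSE_fold values together cover the whole data set; their mean and standard deviation summarize how the architecture generalizes, rather than the score of one particular trained net.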
1 Comment
Chetan Badgujar on 5 Mar 2021
Accepted Answer