
MNIST classification using batch method

2 views (last 30 days)
nadia on 29 Nov 2015
Edited: Greg Heath on 5 Dec 2015
Hi. I want to train a neural network on the MNIST database using the batch method. I use the code below, but my accuracy is very low, even though I think the code is correct. Can anyone help me, please?
function [hiddenWeights, outputWeights, error] = train_network_batch(numberOfHiddenUnits, input, target, epochs, batchSize, learningRate, lambda)
    % The number of training vectors.
    trainingSetSize = size(input, 2);
    % Input vector has 784 dimensions.
    inputDimensions = size(input, 1);
    % We have to distinguish 10 digits.
    outputDimensions = size(target, 1);
    % Initialize the weights for the hidden layer and the output layer.
    % hiddenWeights = randn(NHiddenUnit, inputDimensions)*1/sqrt(size(input, 1));
    % outputWeights = randn(outputDimensions, NHiddenUnit)*1/sqrt(size(input, 1));
    hiddenWeights = rand(numberOfHiddenUnits, inputDimensions);
    outputWeights = rand(outputDimensions, numberOfHiddenUnits);
    hiddenWeights = hiddenWeights./size(hiddenWeights, 2);
    outputWeights = outputWeights./size(outputWeights, 2);
    hiddenWeights_store = hiddenWeights;
    outputWeights_store = outputWeights;
    n = zeros(batchSize, 1);
    validation_count = 0;
    validation_accuracy = 0;
    figure; hold on;
    % Batch method.
    for t = 1:epochs
        for k = 1:batchSize
            % Select which input vector to train on.
            n(k) = floor(rand(1)*trainingSetSize + 1);
            % n(k) = k;
            % Propagate the input vector through the network.
            inputVector = input(:, n(k));
            hiddenActualInput = hiddenWeights*inputVector;
            hiddenOutputVector = linear_func(hiddenActualInput);
            outputActualInput = outputWeights*hiddenOutputVector;
            outputVector = linear_func(outputActualInput);
            targetVector = target(:, n(k));
            % Backpropagate the errors.
            outputDelta = dlinear_func(outputActualInput).*(outputVector - targetVector);
            hiddenDelta = dlinear_func(hiddenActualInput).*(outputWeights'*outputDelta);
            % outputWeights_store = outputWeights_store - (learningRate*lambda/batchSize).*outputWeights - learningRate.*outputDelta*hiddenOutputVector';
            % hiddenWeights_store = hiddenWeights_store - (learningRate*lambda/batchSize).*hiddenWeights - learningRate.*hiddenDelta*inputVector';
            % outputWeights = (1-(learningRate*lambda/batchSize)).*outputWeights - learningRate.*outputDelta*hiddenOutputVector';
            % hiddenWeights = (1-(learningRate*lambda/batchSize)).*hiddenWeights - learningRate.*hiddenDelta*inputVector';
        end
        outputWeights = outputWeights + (outputWeights_store./batchSize);
        hiddenWeights = hiddenWeights + (hiddenWeights_store./batchSize);
        outputWeights_store = 0;
        hiddenWeights_store = 0;
        % ********************************* end of batch method ***************
        % Calculate the error for plotting.
        error = 0;
        for k = 1:batchSize
            inputVector = input(:, n(k));
            targetVector = target(:, n(k));
            error = error + norm(linear_func(outputWeights*linear_func(hiddenWeights*inputVector)) - targetVector, 2);
        end
        error = error/batchSize;
        plot(t, error, '*');
        title(['MSE batch, ', 'NH = ', num2str(numberOfHiddenUnits), ', alfa = ', num2str(learningRate), ', epoch = ', num2str(epochs)]);
        xlabel('epoch');
        ylabel('cost');
        inputValues = load('validation.mat');
        inputValues = inputValues.v;
        labels = load('label.mat');
        labels = labels.l;
        [correctlyClassified, classificationErrors] = validation_network(hiddenWeights, outputWeights, inputValues', labels);
        correctlyClassified = correctlyClassified/10000;
        if correctlyClassified <= validation_accuracy
            validation_count = validation_count + 1;
        else
            validation_count = 0;
        end
        if validation_count > 7
            break;
        end
        validation_accuracy = correctlyClassified;
    end
end
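For reference, here is a minimal sketch of a nonlinear activation pair that could stand in for linear_func/dlinear_func in the code above. The names logistic_func and dlogistic_func are hypothetical, not part of the original code:

```matlab
% Hypothetical logistic sigmoid and its derivative; substituting these
% for linear_func/dlinear_func would make the hidden layer genuinely
% nonlinear instead of a composition of linear maps.
function y = logistic_func(x)
    y = 1 ./ (1 + exp(-x));
end

function d = dlogistic_func(x)
    s = 1 ./ (1 + exp(-x));   % sigmoid(x)
    d = s .* (1 - s);         % d/dx sigmoid(x) = s*(1-s)
end
```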

Accepted Answer

Greg Heath on 5 Dec 2015
Edited: Greg Heath on 5 Dec 2015
1. I don't think anyone wants to wade through all of that code when you can just use MATLAB's classification functions:
help PATTERNNET
doc PATTERNNET
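A minimal patternnet sketch, assuming input is 784-by-N and target is 10-by-N one-hot (as in the question); the hidden-layer size of 20 is an arbitrary illustrative choice:

```matlab
% patternnet builds a feedforward classifier with a nonlinear (tansig)
% hidden layer and trains it with scaled conjugate gradient by default.
net = patternnet(20);             % one hidden layer of 20 units
net = train(net, input, target);  % columns are samples
outputs = net(input);             % 10-by-N class scores
[~, predicted] = max(outputs);    % predicted class index per column
```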
2. If none of your hidden or output functions is nonlinear, then all you have is a complicated linear classifier, which can be implemented directly with BACKSLASH.
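The collapse to a single linear map can be sketched as follows; the sizes and weight matrices are illustrative, not from the question:

```matlab
% With linear activations, two layers W2*(W1*x) equal one layer (W2*W1)*x,
% so the best such classifier comes straight from a least-squares solve
% (backslash in its row-space form, mrdivide).
X = rand(784, 100);         % hypothetical inputs, one column per sample
T = rand(10, 100);          % hypothetical targets, one column per sample
W1 = rand(20, 784);
W2 = rand(10, 20);
y_two_layer = W2*(W1*X);    % two "linear_func" layers...
y_one_layer = (W2*W1)*X;    % ...are exactly one linear map
W = T/X;                    % least-squares linear classifier in one line
```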
Hope this helps.
Thank you for formally accepting my answer
Greg

More Answers (0)

