implementation of mini-batch stochastic gradient descent

20 visualizzazioni (ultimi 30 giorni)
I implemented a mini-batch stochastic gradien descent but counldn't find the bug in my code.
I used this implement to do a classification problem but all my final predictions are 0.
W2 = -1+2*rand(5,2); W3 = -1+2*rand(5,5);
W4 = -1+2*rand(5,5); W5 = -1+2*rand(1,5);
b2 = -1+2*rand(5,1); b3 = -1+2*rand(5,1);
b4 = -1+2*rand(5,1); b5 = -1+2*rand(1,1);
eta = 5e-3; % learning rate
iter = 1000; % number of iterations
num_data = length(label);
loss_vec = zeros(1,iter);
tloss_vec = zeros(1,iter);
for it = 1:iter
% mini-batch method
batch_size = 50;
rand_idx = randperm(num_data);
rand_idx = reshape(rand_idx,[],num_data/batch_size);
for idx = rand_idx
% forward pass
a2 = activate([x1(:,idx);x2(:,idx)], W2, b2);
a3 = activate(a2,W3,b3);
a4 = activate(a3,W4,b4);
a5 = activate(a4,W5,b5);
% backward pass (gradient)
delta5 = a5.*(1-a5).*(a5-label(idx));
delta4 = a4.*(1-a4).*(W5'*delta5);
delta3 = a3.*(1-a3).*(W4'*delta4);
delta2 = a2.*(1-a2).*(W3'*delta3);
% update weights and bias
W2 = W2 - 1/length(idx)*eta*delta2*[x1(:,idx);x2(:,idx)]';
W3 = W3 - 1/length(idx)*eta*delta3*a2';
W4 = W4 - 1/length(idx)*eta*delta4*a3';
W5 = W5 - 1/length(idx)*eta*delta5*a4';
b2 = b2 - 1/length(idx)*eta*sum(delta2,2);
b3 = b3 - 1/length(idx)*eta*sum(delta3,2);
b4 = b4 - 1/length(idx)*eta*sum(delta4,2);
b5 = b5 - 1/length(idx)*eta*sum(delta5,2);
% compute train loss and test loss
loss_vec(it) = 1/(2*num_data)*LossFunc(W2,W3,W4,W5,b2,b3,b4,b5,[x1;x2],label);
tloss_vec(it) = 1/(2*200)*LossFunc(W2,W3,W4,W5,b2,b3,b4,b5,[tx1;tx2],tlabel);
end
end
%% cost function
function loss = LossFunc(W2,W3,W4,W5,b2,b3,b4,b5,x,y)
a2 = activate(x, W2, b2);
a3 = activate(a2, W3, b3);
a4 = activate(a3, W4, b4);
a5 = activate(a4, W5, b5);
loss = norm(a5-y,2)^2;
end
%% prediction
function pred = predict(W2,W3,W4,W5,b2,b3,b4,b5,x)
a2 = activate(x, W2, b2);
a3 = activate(a2, W3, b3);
a4 = activate(a3, W4, b4);
a5 = activate(a4, W5, b5);
pred = round(a5);
end
%% activation function
function y = activate(x,W,b)
y = 1./(1+exp(-(W*x+b)));
end

Risposte (2)

Mahesh Taparia
Mahesh Taparia il 2 Apr 2021
Hi
You mentioned that you are implementing a classification network. In your code, you are using square of L2 norm to calculate the loss and loss derivative is also not correct while doing back propagation. Moreover, since it is a classification network, use the classification loss like cross entropy loss, focalcrossentropy, etc instead of norm. May be this is the reason you are getting 0 everytime.
Also, you can use MATLAB inbuilt function to perform back propagation. For this, you can refer the link given below:
Hope it will help!
  1 Commento
konoha
konoha il 2 Apr 2021
Modificato: konoha il 2 Apr 2021
the derivative of mes is -(y-f(x))f'(x). I don't follow your suggestions.
Thank you.

Accedi per commentare.


Mohamed Salem
Mohamed Salem il 22 Dic 2022
Write a MATLAB code, that implement Dalta learning rule with mini-batch.
Compare (with graph) your mini-batch algorithm with SGD, Batch algorithm in terms of mean square error.

Prodotti


Release

R2020b

Community Treasure Hunt

Find the treasures in MATLAB Central and discover how the community can help you!

Start Hunting!

Translated by