Deep learning: predict and forward give very inconsistent results with batch normalization
Hi there,
I have a DNN with BatchNormalization layers. I called [Uf,statef] = forward(dnn_net,XT); the statef returned by forward should contain the updated mean and variance of the BN layers. I then set dnn_net.State = statef, so that dnn_net.State holds the mean and variance from the forward call. Then I ran Up = predict(dnn_net,XT) on the same data XT used in the forward call. When I compare mse(Up,Uf), the error is quite large (e.g., 2.xx). This is completely unexpected. Please help.
I have created a simple test script to show the problem below:
numLayer = 9;
numNeurons = 80;
layers = featureInputLayer(2); % 2 input features
for i = 1:numLayer-1
    layers = [
        layers
        fullyConnectedLayer(numNeurons)
        batchNormalizationLayer
        geluLayer
        ];
end
layers = [
    layers
    fullyConnectedLayer(4)];
dnn_net = dlnetwork(layers,'Initialize',true);
dnn_net = expandLayers(dnn_net); % expand nested layers (e.g., residual blocks) into individual layers
dnn_net = dlupdate(@double,dnn_net);
X = 2*pi*dlarray(randn(1,1000),"CB");
T = 2*pi*dlarray(randn(1,1000),"CB");
XT = cat(1,X,T);
% Compute the DNN output with forward; statef contains the BN layers'
% updated means and variances.
[Uf,statef] = forward(dnn_net,XT);
dnn_net.State = statef; % update dnn_net.State with the BN means/variances from the forward call
Up = predict(dnn_net,XT);
% The outputs of predict and forward should be the same in this case
% (because dnn_net.State was updated with the mean/variance from the
% forward call). However, they are not: the error is quite large.
plot(Uf(1,:),'r-'); hold on; plot(Up(1,:),'b-');
err = mse(Uf(1,:),Up(1,:))
Answers (1)
To get agreement, you need to let the trained mean and variance converge. The state returned by forward is a moving average of the batch statistics, not the batch statistics themselves, so a single update is not enough; repeated forward calls on the same data drive it to the batch values:
numLayer = 9;
numNeurons = 80;
layers = featureInputLayer(2); % 2 input features
for i = 1:numLayer-1
    layers = [
        layers
        fullyConnectedLayer(numNeurons)
        batchNormalizationLayer
        geluLayer
        ];
end
layers = [
    layers
    fullyConnectedLayer(4)];
dnn_net = dlnetwork(layers,'Initialize',true);
dnn_net = expandLayers(dnn_net); % expand nested layers (e.g., residual blocks) into individual layers
dnn_net = dlupdate(@double,dnn_net);
X = 2*pi*dlarray(randn(1,1000),"CB");
T = 2*pi*dlarray(randn(1,1000),"CB");
XT = cat(1,X,T);
% Call forward repeatedly so that the moving mean/variance in the state
% converge to the statistics of this batch.
for i = 1:500
    [Uf,statef] = forward(dnn_net,XT);
    dnn_net.State = statef; % update dnn_net.State with the latest moving mean/variance
end
Up = predict(dnn_net,XT);
% Once the state has converged, the outputs of predict and forward agree
% and the error is small.
plot(Uf(1,:),'r-'); hold on; plot(Up(1,:),'--b');
err = mse(Uf(1,:),Up(1,:))
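Why so many iterations: batchNormalizationLayer stores its mean and variance as exponential moving averages that are nudged toward the batch statistics on every forward call. Here is a minimal sketch of that update, assuming the default decay of 0.1 (muState and muBatch are illustrative scalars, not names from the network):
decay   = 0.1; % assumed default MeanDecay/VarianceDecay of batchNormalizationLayer
muState = 0;   % initial moving mean stored in the network state
muBatch = 3;   % hypothetical batch mean seen by one BN layer
for k = 1:500
    % each forward call moves the state a fraction "decay" toward the batch value
    muState = (1 - decay)*muState + decay*muBatch;
end
abs(muState - muBatch) % ~0 after 500 updates; under this rule, a single update would leave a gap of 2.7
After one update the state is still far from the batch mean, which is why a single forward call followed by predict gives inconsistent outputs.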
6 Comments
Matt J
on 25 Aug 2024
Did you run my code above, which demonstrates that convergence does occur? How is the result you are getting different from what is shown there?
No, they are not expected to have consistent gradients. When you use forward(), batch normalization normalizes with the mean and variance of the activations in the current batch. dlgradient treats these as moving values, i.e., as functions of the learnable parameters.
When you use predict(), however, the normalization uses the finalized population statistics for the mean and variance, which dlgradient() treats as constants.
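To see this concretely, here is a hedged sketch (modelGradients is an illustrative helper, not a toolbox function) that differentiates the same scalar loss through both paths:
% usage (in a script, with the local function at the end of the file):
gradF = dlfeval(@modelGradients, dnn_net, XT, false); % training-mode BN
gradP = dlfeval(@modelGradients, dnn_net, XT, true);  % inference-mode BN
% gradF and gradP generally differ even with identical State, because the
% two BN treatments create different dependence on the learnables.
function gradients = modelGradients(net, X, usePredict)
    if usePredict
        Y = predict(net, X); % BN statistics read from net.State: constants to dlgradient
    else
        Y = forward(net, X); % BN statistics computed from the batch: differentiated through
    end
    loss = sum(Y.^2, 'all');
    gradients = dlgradient(loss, net.Learnables);
end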
Tom
on 26 Aug 2024