Deep learning: predict and forward give very inconsistent results with batch normalization

Hi there,
I have a DNN with batchNormalization layers. I call [Uf, statef] = forward(dnn_net, XT); the statef returned by forward should contain the learned mean and variance of the BN layers. I then set dnn_net.State = statef so that the network state carries those learned statistics, and call Up = predict(dnn_net, XT) on the same data XT as the forward call. When I compare mse(Up, Uf), the error is quite large (e.g. 2.xx). This is totally unexpected. Please help.
I have created a simple test script to show the problem below:
numLayer = 9;
numNeurons = 80;
layers = featureInputLayer(2);% 2 input features
for i = 1:numLayer-1
    layers = [
        layers
        fullyConnectedLayer(numNeurons)
        batchNormalizationLayer
        geluLayer
        ];
end
layers = [
    layers
    fullyConnectedLayer(4)];
dnn_net = dlnetwork(layers,'Initialize',true);
dnn_net = expandLayers(dnn_net);% expand the layers in residual blocks
dnn_net = dlupdate(@double,dnn_net);
X = 2*pi*dlarray(randn(1,1000),"CB");
T = 2*pi*dlarray(randn(1,1000),"CB");
XT = cat(1,X,T);
% compute the DNN output with forward; statef contains the batch
% normalization layers' learned mean and variance
[Uf,statef] = forward(dnn_net,XT);
dnn_net.State = statef; % update dnn_net.State with the learned mean/var from the forward call
Up = predict(dnn_net,XT);
% The DNN outputs from the predict call and the forward call should be the
% same here (because dnn_net.State was updated with the learned mean/var
% from the forward call). However, that is not the case: the error is quite large.
plot(Uf(1,:),'r-');hold on; plot(Up(1,:),'b-');
err = mse(Uf(1,:),Up(1,:))
err =
  1x1 dlarray
    0.3134

Answers (1)

To get agreement, you need to allow the trained mean and variance to converge:
numLayer = 9;
numNeurons = 80;
layers = featureInputLayer(2);% 2 input features
for i = 1:numLayer-1
    layers = [
        layers
        fullyConnectedLayer(numNeurons)
        batchNormalizationLayer
        geluLayer
        ];
end
layers = [
    layers
    fullyConnectedLayer(4)];
dnn_net = dlnetwork(layers,'Initialize',true);
dnn_net = expandLayers(dnn_net);% expand the layers in residual blocks
dnn_net = dlupdate(@double,dnn_net);
X = 2*pi*dlarray(randn(1,1000),"CB");
T = 2*pi*dlarray(randn(1,1000),"CB");
XT = cat(1,X,T);
% compute the DNN output with forward; statef contains the batch
% normalization layers' learned mean and variance
for i = 1:500
    [Uf,statef] = forward(dnn_net,XT);
    dnn_net.State = statef; % update dnn_net.State with the learned mean/var from the forward call
end
Up = predict(dnn_net,XT);
% Once the state has converged, the outputs of the predict call and the
% forward call agree; the error is now negligible.
plot(Uf(1,:),'r-');hold on; plot(Up(1,:),'--b');
err = mse(Uf(1,:),Up(1,:))
err =
  1x1 dlarray
    2.6776e-12
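To see why the repeated forward() calls are needed: batchNormalizationLayer updates its stored statistics as an exponential moving average, controlled (as I understand it) by the MeanDecay and VarianceDecay properties, which default to 0.1. A toy sketch of that update on constant data (the decay value is an assumption; check the batchNormalizationLayer doc page for your release):

```matlab
% Sketch: each forward() call nudges the stored statistic toward the
% current batch statistic by a fixed fraction, so repeated calls on the
% same data converge geometrically.
decay      = 0.1;   % assumed default MeanDecay of batchNormalizationLayer
batchMean  = 1.0;   % batch statistic (fixed, since the data never changes)
movingMean = 0;     % initial stored statistic
for k = 1:500
    movingMean = decay*batchMean + (1-decay)*movingMean;
end
disp(abs(movingMean - batchMean))  % error is (1-decay)^500, effectively zero
```

After 500 updates the gap has shrunk by a factor of 0.9^500 (about 1e-23), which is why 500 iterations suffice while a single forward() call leaves a large mismatch.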

6 Comments

Thanks, but how do I get them to converge? I also tried calling forward 1000 times before calling predict (on the same XT), but nothing changed.

Did you run my code above, which demonstrates that convergence does occur? How is the result you are getting different from what is shown there?
Oh, sorry, I overlooked that. Thanks, it works. However, when I check the dlgradient results from forward() and predict(), they are inconsistent (although the U values are consistent now). Here is my modified code (thanks for helping):
numLayer = 9;
numNeurons = 80;
layers = featureInputLayer(2);% 2 input features
for i = 1:numLayer-1
    layers = [
        layers
        fullyConnectedLayer(numNeurons)
        batchNormalizationLayer
        geluLayer
        ];
end
% 4 output channels in the final layer
layers = [
    layers
    fullyConnectedLayer(4)];
% setup the nn
dnn_net = dlnetwork(layers,'Initialize',true);
%dnn_net = expandLayers(dnn_net);% expand the layers in residual blocks
dnn_net = dlupdate(@double,dnn_net);
X = 2*pi*dlarray(rand(1,1000),"CB");
T = 2*pi*dlarray(rand(1,1000),"CB");
XT = cat(1,X,T);
% compute the DNN output with forward; statef contains the batch
% normalization layers' learned mean and variance
for i = 1:500
    [Uf,statef] = forward(dnn_net,XT);
    dnn_net.State = statef; % update dnn_net.State with the learned mean/var from the forward call
end
Up1 = predict(dnn_net,XT);
% The DNN outputs from the predict call and the forward call should be the same
err1 = mse(Uf(1,:),Up1(1,:));
[Uf, Uxf, Utf, Uxxf, Uttf, dnn_states_f,gradients_f] = dlfeval(@test,dnn_net,X,T,0);
dnn_net.State = dnn_states_f;
[Up, Uxp, Utp, Uxxp, Uttp, dnn_states_p,gradients_p] = dlfeval(@test,dnn_net,X,T,1);
err2 = mse(Up,Uf);
err2x = mse(Uxp,Uxf);
err2xx = mse(Uxxp,Uxxf);
======= and this is the function file test.m =================
function [U, Ux, Ut, Uxx, Utt, dnn_states, gradients] = test(dnn_net, X, T, mode)
    XT = cat(1, X, T); % 2 input features on one input channel
    if mode == 0
        % return dnn_states so the batch normalization layers get updated
        % during custom training (then predict and forward are consistent)
        [U, dnn_states] = forward(dnn_net, XT);
    else
        [U, dnn_states] = predict(dnn_net, XT);
    end
    % First-order derivatives with respect to X and T.
    gradientsU1 = dlgradient(sum(U(1,:),"all"),{X,T},'EnableHigherDerivatives',true);
    Ux = gradientsU1{1};
    Ut = gradientsU1{2};
    % Second-order derivative with respect to T (note: differentiate Ut, not Ux).
    Utt = dlgradient(sum(Ut,"all"),T);
    % Second-order derivative with respect to X.
    Uxx = dlgradient(sum(Ux,"all"),X);
    loss = sum(U,"all");
    gradients = dlgradient(loss,dnn_net.Learnables);
end
No, they are not expected to have consistent gradients. When using forward(), batch normalization uses the mean and variance of the activations coming from the previous layers in the current batch. These are running values, i.e., functions of the learnable parameters and inputs, as far as dlgradient is concerned.
When using predict(), however, the normalization uses the finalized population statistics for the mean and variance, which are treated as constants by dlgradient().
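The effect is easy to reproduce without a network. The sketch below (my own illustration, not from the thread) strips batch normalization down to mean subtraction: in the forward() case the mean depends on the input, so its derivative enters the chain rule; in the predict() case the stored mean is a constant (0.3 here is an arbitrary stand-in for a stored statistic):

```matlab
% Why the gradients differ: for f(x) = sum_i (x_i - mu) with mu = mean(x),
% d/dx_i f = 1 - 1 (the 1/N terms sum to 1), so the gradient is all zeros.
% With a stored constant mu0 instead, d/dx_i f = 1, all ones.
x = dlarray(randn(1,8),"CB");
fwdStyle = @(x) sum(x - mean(x,2),"all");  % mean depends on x, like forward()
prdStyle = @(x) sum(x - 0.3,"all");        % stored constant, like predict()
gF = dlfeval(@(x) dlgradient(fwdStyle(x),x), x);  % all zeros
gP = dlfeval(@(x) dlgradient(prdStyle(x),x), x);  % all ones
```

The same mechanism, with the variance as well, is what makes the dlgradient results from forward() and predict() disagree even when the outputs match.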
Hi Matt,
Thanks a lot for your explanation. Is there a way to set up the BN layer so that it uses a supplied mean and variance for normalization when calling forward()? In my application it is important for the BN statistics to be treated as constants when differentiating U with respect to X and T. Thanks in advance.
Then don't use forward(). Use predict().
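Following that advice, one possible pattern (a sketch under my own assumptions, not verified against this exact network) is to refresh the State with forward() outside of any dlfeval call, then do all differentiation through predict(), so the BN statistics behave as constants; the lossFcn name is illustrative:

```matlab
% Sketch: update BN running statistics without tracing, then differentiate
% through predict() so the stored mean/var are constants for dlgradient.
[~, state] = forward(dnn_net, XT);  % plain call: only refreshes statistics
dnn_net.State = state;
[U, Ux] = dlfeval(@lossFcn, dnn_net, X, T);

function [U, Ux] = lossFcn(net, X, T)
    XT = cat(1, X, T);
    U  = predict(net, XT);          % BN mean/var treated as constants here
    Ux = dlgradient(sum(U(1,:),"all"), X, 'EnableHigherDerivatives', true);
end
```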



Release

R2024a

Asked:

Tom
on 25 Aug 2024

Edited:

on 4 Sep 2024
