Why are all the gradients of the generator zero (from beginning to end) when training a GAN?
I want to train a generator that produces values of a sine function. However, when training the GAN, all the gradients of the generator are zero. I do not know what the problem is. Could anyone help me?
The code is as follows:
batch_size = 64;
n_ideas = 5;
art_components = 15;
step = 2/(art_components-1);
points = -1:step:1;
paint_points = repmat(points,batch_size,1);
Generator = [
    featureInputLayer(n_ideas)
    fullyConnectedLayer(128)
    reluLayer
    fullyConnectedLayer(art_components)
];
Discriminator = [
    featureInputLayer(art_components)
    fullyConnectedLayer(128)
    reluLayer
    fullyConnectedLayer(1)
    sigmoidLayer
];
net_g = dlnetwork(Generator);
net_d = dlnetwork(Discriminator);
lr = 0.0001;
decay = 0.90;
sqdecay = 0.999;
avg_decay_g = [];
avd_sqdecay_g = [];
avg_decay_d = [];
avd_sqdecay_d = [];
for e=1:10000
    artis_paintings = dlarray(single(artist_work(art_components,paint_points)),'BC');
    % update learnable parameters of discriminator
    g_ideas = dlarray(single(randn(batch_size,n_ideas)),'BC');
    g_paintings = forward(net_g,g_ideas);
    [loss_d,gradient_d,score_d] = ...
        dlfeval(@d_loss,net_d,artis_paintings,g_paintings);
    [net_d, avg_decay_d, avd_sqdecay_d] = ...
        adamupdate(net_d,gradient_d,avg_decay_d,avd_sqdecay_d,e,lr,decay,sqdecay);
    % update learnable parameters of generator
    g_ideas = dlarray(single(randn(batch_size,n_ideas)),'BC');
    g_paintings = forward(net_g,g_ideas);
    prob_artist1 = forward(net_d,g_paintings);
    [loss_g,gradient_g,score_g] = ...
        dlfeval(@g_loss,net_g,prob_artist1);
    [net_g, avg_decay_g, avd_sqdecay_g] = ...
        adamupdate(net_g,gradient_g,avg_decay_g,avd_sqdecay_g,e,lr,decay,sqdecay);
end
function [loss_d,gradient_d,score_d] = ...
        d_loss(net_d,artis_paintings,g_paintings)
    % calculate loss
    prob_artist0 = forward(net_d,artis_paintings);
    prob_artist1 = forward(net_d,g_paintings);
    score_d = mean(1-prob_artist1);
    loss_d = -mean(log(prob_artist0)) - mean(log(1-prob_artist1));
    % calculate gradients
    gradient_d = dlgradient(loss_d, net_d.Learnables);
end
function [loss_g,gradient_g,score_g] = ...
        g_loss(net_g,prob_artist1)
    score_g = mean(prob_artist1);
    % calculate gradients
    loss_g = -mean(log(prob_artist1));
    gradient_g = dlgradient(loss_g, net_g.Learnables);
end
function paintings = artist_work(art_components,paint_points)
    r = 0.02 * randn(1,art_components);
    paintings = sin(paint_points * pi) + r;
end
0 Comments
Accepted Answer
Richard
on 5 Nov 2022
All of the calculations that sit "between" the variables you want gradients with respect to and the loss value need to be contained inside the function that you pass to dlfeval. If they are not, the dlgradient call cannot know they have occurred and will conclude that there is no dependency between the outputs and the inputs, hence the gradients are all zero.
In this case, you must ensure that the "forward(net)" calls are inside the loss functions. You have done this correctly for the discriminator loss, but for the generator loss you need to pass in both the generator and discriminator networks and call forward on each one inside g_loss:
function [loss_g,gradient_g,score_g] = g_loss(net_g,net_d,g_ideas)
    g_paintings = forward(net_g,g_ideas);
    prob_artist1 = forward(net_d,g_paintings);
    score_g = mean(prob_artist1);
    % calculate gradients
    loss_g = -mean(log(prob_artist1));
    gradient_g = dlgradient(loss_g, net_g.Learnables);
end
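The generator-update step in the training loop then passes the noise (rather than a precomputed prob_artist1) into dlfeval. This is a minimal sketch of how that call could look with the revised g_loss, reusing the variable names from the question:
    % update learnable parameters of generator: the forward passes now
    % happen inside g_loss, so dlgradient can trace the dependency
    g_ideas = dlarray(single(randn(batch_size,n_ideas)),'BC');
    [loss_g,gradient_g,score_g] = ...
        dlfeval(@g_loss,net_g,net_d,g_ideas);
    [net_g, avg_decay_g, avd_sqdecay_g] = ...
        adamupdate(net_g,gradient_g,avg_decay_g,avd_sqdecay_g,e,lr,decay,sqdecay);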
2 Comments
Richard
on 6 Nov 2022
Thanks for the feedback @You Jinkun. I will submit a documentation enhancement request regarding this aspect of the dlfeval/dlgradient interaction.
You can also use the "How useful was this information?" section at the bottom of any of our doc pages to directly submit feedback to our doc team if there is a specific page that you think could be improved (clicking on the rating opens a text field for submitting a specific comment).
More Answers (0)