The adamupdate function in MATLAB R2024b incorrectly uses uint32 with sqrt and exhibits state corruption, causing errors even in minimal test cases.

% Test adamupdate function
clc;
clear;
% Define test parameters
learnable = dlarray(randn(5, 1)); % Example learnable parameter
gradient = dlarray(randn(5, 1)); % Example gradient
state = []; % Initial state (empty)
optimizer = trainingOptions('adam', 'InitialLearnRate', 0.01); % Example optimizer
timeStep = uint32(1); % Initial time step
try
    % Perform a single adamupdate
    updatedLearnable = adamupdate(learnable, gradient, state, optimizer, timeStep);
    % Display results
    disp('adamupdate test successful!');
    disp('Updated Learnable:');
    disp(updatedLearnable);
catch ME
    % Display error message
    disp('adamupdate test failed!');
    disp(['Error: ', ME.message]);
    % Display the stack trace to help MathWorks track the problem.
    disp('Stack Trace:');
    disp(ME.stack);
end
adamupdate test failed!
Error: Undefined function 'sqrt' for input arguments of type 'uint32'.
Stack Trace:
9x1 struct array with fields: file name line
% Perform a second adam update to test state persistence.
timeStep = uint32(2);
try
    % Perform a second adamupdate
    updatedLearnable2 = adamupdate(learnable, gradient, state, optimizer, double(timeStep));
    % Display results
    disp('adamupdate second test successful!');
    disp('Updated Learnable:');
    disp(updatedLearnable2);
catch ME
    % Display error message
    disp('adamupdate second test failed!');
    disp(['Error: ', ME.message]);
    % Display the stack trace to help MathWorks track the problem.
    disp('Stack Trace:');
    disp(ME.stack);
end
adamupdate second test failed!
Error: dlarray is supported only for full arrays of data type double, single, or logical, or for full gpuArrays of these data types.
Stack Trace:
12x1 struct array with fields: file name line
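The first failure is consistent with the uint32 timeStep landing in the iteration argument: in MATLAB, arithmetic that mixes a double scalar with a uint32 value returns uint32, and sqrt accepts only floating-point inputs. The base-MATLAB sketch below reproduces that message. It assumes (the thread does not confirm this) that the bias correction inside adamupdate evaluates something like sqrt(1 - sqGradDecay^iteration); sqGradDecay and biasTerm are hypothetical local names.

```matlab
% Base MATLAB only: how a uint32 iteration counter can reach sqrt.
% ASSUMPTION: adamupdate's bias correction involves an expression like
% sqrt(1 - sqGradDecay^iteration); its internals are not shown in the thread.
sqGradDecay = 0.999;          % hypothetical name for the squared-gradient decay
iteration   = uint32(1);      % same class as the question's timeStep

% double scalar combined with uint32 gives uint32; 0.999 is rounded to 1
biasTerm = 1 - sqGradDecay^iteration;
disp(class(biasTerm))
disp(biasTerm)

try
    sqrt(biasTerm);           % sqrt does not accept integer classes
catch ME
    disp(ME.message)
end

% Converting the iteration to double keeps the arithmetic in floating point.
biasTermDouble = 1 - sqGradDecay^double(iteration);
disp(sqrt(biasTermDouble))
```

Besides raising the error, the integer path also rounds the correction term to zero, so passing the iteration as an ordinary double is the safer choice.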
Chika on 18 Mar 2025
Error message:
adamupdate test failed!
Error: Undefined function 'sqrt' for input arguments of type 'uint32'.
Stack Trace:
7×1 struct array with fields:
file
name
line
adamupdate second test failed!
Error: dlarray is supported only for full arrays of data type double, single, or logical, or for full gpuArrays of these data types.
Stack Trace:
10×1 struct array with fields:
file
name
line


Accepted Answer

Joss Knight on 22 Mar 2025
Well, I admit the error messages aren't very helpful, but the basic problem is that passing a trainingOptions object in as an argument to adamupdate is not supported. See the documentation for the correct syntax.
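In the documented syntax, the optimizer settings are not bundled into a trainingOptions object: the learning rate and decay factors are positional arguments after the iteration number, the moving averages can start as empty arrays (treated as the first update), and the iteration counter is an ordinary double. A sketch of a three-step loop under those assumptions; random gradients stand in for real ones, grad is a hypothetical name chosen to avoid shadowing MATLAB's gradient function, and learnRate = 0.01 mirrors the question's InitialLearnRate.

```matlab
% Sketch of the documented adamupdate calling pattern (Deep Learning Toolbox).
% Hyperparameters follow the iteration number; no trainingOptions input.
learnable     = dlarray(randn(5, 1));
averageGrad   = [];       % empty state: treated as the first update
averageSqGrad = [];
learnRate     = 0.01;     % rate the question tried to set via trainingOptions
gradDecay     = 0.9;      % documented default
sqGradDecay   = 0.999;    % documented default

for iteration = 1:3                          % double-valued counter
    grad = dlarray(randn(5, 1));             % placeholder gradient
    [learnable, averageGrad, averageSqGrad] = adamupdate(learnable, grad, ...
        averageGrad, averageSqGrad, iteration, learnRate, gradDecay, sqGradDecay);
end
disp(extractdata(learnable))
```

Reassigning learnable and both moving averages on every pass is what carries the optimizer state from one iteration to the next.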
Chika on 22 Mar 2025
I am extremely grateful to Joss Knight for pointing out the error and for advising me to check the documentation for the adamupdate function.


More Answers (1)

Chika on 22 Mar 2025
% Corrected code, following the documentation as advised by Joss Knight
% Test adamupdate function (Built-in)
clc;
clear;
% Define test parameters
learnable = dlarray(randn(5, 1)); % Example learnable parameter
gradient = dlarray(randn(5, 1)); % Example gradient
averageGrad = zeros(size(learnable)); % Initialize average gradient
averageSqGrad = zeros(size(learnable)); % Initialize average squared gradient
iteration = 1; % Initial iteration
try
    % Perform a single adamupdate
    [updatedLearnable, averageGrad, averageSqGrad] = adamupdate(learnable, gradient, averageGrad, averageSqGrad, iteration);
    % Display results
    disp('adamupdate test successful!');
    disp('Updated Learnable:');
    disp(updatedLearnable);
    disp('Average Gradient:');
    disp(averageGrad);
    disp('Average Squared Gradient:');
    disp(averageSqGrad);
catch ME
    % Display error message
    disp('adamupdate test failed!');
    disp(['Error: ', ME.message]);
    % Display the stack trace.
    disp('Stack Trace:');
    disp(ME.stack);
end
adamupdate test successful!
Updated Learnable:
5x1 dlarray -0.7648 -1.0165 -0.0125 -0.5996 0.4997
Average Gradient:
5x1 dlarray 0.1693 -0.0385 0.0958 -0.0383 0.0295
Average Squared Gradient:
5x1 dlarray 0.0029 0.0001 0.0009 0.0001 0.0001
% Perform a second adam update to test state persistence.
iteration = 2;
try
    % Perform a second adamupdate, passing in the updated state.
    % Note: the original learnable (not updatedLearnable) is passed again,
    % so only averageGrad and averageSqGrad advance between the two calls.
    [updatedLearnable2, averageGrad2, averageSqGrad2] = adamupdate(learnable, gradient, averageGrad, averageSqGrad, iteration);
    % Display results
    disp('adamupdate second test successful!');
    disp('Updated Learnable:');
    disp(updatedLearnable2);
    disp('Average Gradient:');
    disp(averageGrad2);
    disp('Average Squared Gradient:');
    disp(averageSqGrad2);
catch ME
    % Display error message
    disp('adamupdate second test failed!');
    disp(['Error: ', ME.message]);
    % Display the stack trace.
    disp('Stack Trace:');
    disp(ME.stack);
end
adamupdate second test successful!
Updated Learnable:
5x1 dlarray -0.7648 -1.0165 -0.0125 -0.5996 0.4997
Average Gradient:
5x1 dlarray 0.3217 -0.0732 0.1819 -0.0727 0.0560
Average Squared Gradient:
5x1 dlarray 0.0057 0.0003 0.0018 0.0003 0.0002
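Why both runs print the same Updated Learnable: the second call receives the original learnable and the same gradient g, and differs only in the stored moving averages. Assuming adamupdate applies the standard bias-corrected Adam rule with its default gradDecay β1 = 0.9, sqGradDecay β2 = 0.999 and learnRate α = 0.001, the bias-corrected averages come out identical at iterations 1 and 2:

```latex
\begin{aligned}
m_1 &= (1-\beta_1)\,g = 0.1\,g, &
v_1 &= (1-\beta_2)\,g^2 = 0.001\,g^2,\\
m_2 &= \beta_1 m_1 + (1-\beta_1)\,g = 0.19\,g, &
v_2 &= \beta_2 v_1 + (1-\beta_2)\,g^2 = 0.001999\,g^2,\\
\hat m_t &= \frac{m_t}{1-\beta_1^{t}} = g, &
\hat v_t &= \frac{v_t}{1-\beta_2^{t}} = g^2 \qquad (t = 1, 2),\\
\theta_{\text{new}} &= \theta - \alpha\,\frac{\hat m_t}{\sqrt{\hat v_t}+\epsilon}
  \approx \theta - \alpha\,\operatorname{sign}(g).
\end{aligned}
```

The ratio m2/m1 = 1.9 matches the printed averages (0.3217/0.1693 ≈ 1.90), and v2/v1 = 1.999 matches the squared averages roughly doubling, within display rounding.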

Categories

Find more on Image Data Workflows in Help Center and File Exchange

Release

R2024b
