Complex number gradient using 'dlgradient' in conjunction with neural networks
Dr. Veerababu Dharanalakota
on 7 Apr 2023
Commented: Walter Roberson on 7 Apr 2023
Hello All,
I am trying to find the gradient of a function f that multiplies the output of a feedforward neural network by C, where C is a complex-valued constant, x is the (real-valued) input vector of the network and θ are its (real-valued) parameters. The output of the neural network is a real-valued array; however, because of the complex constant C, the function f becomes complex-valued. I would like to find its gradient with respect to the input vector x.
I tried to follow the method described in https://in.mathworks.com/help/deeplearning/ref/dlarray.dlgradient.html, which is given below (with my modifications):
clc;
clear all;
x = linspace(1,10,5); % Real-valued array
x = dlarray(x,"CB"); % Convert to a formatted deep learning array (channel x batch)
[y, grad] = dlfeval(@gradFun,x);
grad = extractdata(grad)
% Complex-valued function
function y = complexFun(x)
y = (2+3j)*x.^2;
end
% Function to calculate the gradient
function [y,grad] = gradFun(x)
y = complexFun(x);
y = real(y); % dlgradient differentiates a real-valued scalar
grad = dlgradient(sum(y,"all"),x,'EnableHigherDerivatives',true);
end
This method successfully calculates the gradient of the complex-valued function (giving, of course, the conjugate output). I then tried the same approach with the real-valued part of the function replaced by the neural network. When I did this, I encountered the following error:
"Encountered complex value when computing gradient with respect to an output of fullyconnect. Convert all outputs of fullyconnect to real".
I would be grateful if anyone could show a way to fix the error and calculate the gradients.
Thank you,
Dr. Veerababu Dharanalakota
Accepted Answer
Walter Roberson
on 7 Apr 2023
The derivative of C*f(x) can be calculated using the product rule: dC/dx*f(x) + C*df/dx. But when C is constant, dC/dx is 0, no matter whether C is real or complex valued. Therefore the derivative of C*f(x) is C*df/dx. The same logic applies to second derivatives.
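Written out in full, assuming (as stated in the question) that C depends on neither x nor θ, the argument above is:

```latex
\frac{\partial}{\partial x}\bigl(C\,f(x)\bigr)
  = \frac{\partial C}{\partial x}\,f(x) + C\,\frac{\partial f}{\partial x}
  = 0\cdot f(x) + C\,\frac{\partial f}{\partial x}
  = C\,\frac{\partial f}{\partial x},
\qquad
\frac{\partial^{2}}{\partial x^{2}}\bigl(C\,f(x)\bigr) = C\,\frac{\partial^{2} f}{\partial x^{2}}.
```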
Therefore the gradient of C*f(x) is C times the gradient of f(x). If f(x) is real valued, as indicated, and C is complex valued, then unless the gradient is 0 the gradient of C*f(x) will be complex valued, which dlgradient refuses to work with.
So take the dlgradient of f(x) and multiply the result by C. That should at least postpone the problem.
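A minimal sketch of this suggestion follows. It assumes a hypothetical one-hidden-layer network (the local function hypotheticalNetwork, built from fullyconnect and tanh with random real weights stored in a hypothetical params struct) and reuses C = 2+3j from the question's example; none of these come from the original network. Only the real-valued network output is traced by dlfeval/dlgradient, and C is applied to the extracted gradient afterwards.

```matlab
rng(0);                              % reproducible random weights
C = 2 + 3j;                          % complex constant, kept OUT of the traced computation

x = dlarray(linspace(1,10,5), "CB"); % 1 channel, 5 observations (real-valued)

% Hypothetical real-valued network parameters (assumption, for illustration only)
params.W1 = dlarray(randn(10,1));
params.b1 = dlarray(zeros(10,1));
params.W2 = dlarray(randn(1,10));
params.b2 = dlarray(0);

% Differentiate the real-valued network output only
[N, gradN] = dlfeval(@networkGradient, x, params);

% Apply the complex constant afterwards: grad(C*N) = C*grad(N)
gradF = C * extractdata(gradN);
f     = C * extractdata(N);

function [N, gradN] = networkGradient(x, params)
    N = hypotheticalNetwork(x, params);
    % N is real, so sum(N,"all") is the real scalar dlgradient expects
    gradN = dlgradient(sum(N, "all"), x, "EnableHigherDerivatives", true);
end

function N = hypotheticalNetwork(x, params)
    % Hypothetical stand-in for the feedforward network in the question
    h = tanh(fullyconnect(x, params.W1, params.b1));
    N = fullyconnect(h, params.W2, params.b2);
end
```

Because the five observations are processed independently, differentiating sum(N,"all") yields dN/dx for each observation. If the scaled gradient has to stay inside a traced computation (for example, as part of a loss), one option is to carry real(C)*gradN and imag(C)*gradN as two separate real-valued dlarrays instead of a single complex one.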