How does selfAttentionLayer work? Implementing validation with brief code

How does selfAttentionLayer work in detail, every step of the way? Can you simply reproduce its working process based on the paper's formula, thus verifying selfAttentionLayer's correctness and consistency?
Official description:
A self-attention layer computes single-head or multihead self-attention of its input.
The layer:
  1. Computes the queries, keys, and values from the input
  2. Computes the scaled dot-product attention across heads using the queries, keys, and values
  3. Merges the results from the heads
  4. Performs a linear transformation on the merged result
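For reference, these are the formulas from the Transformer paper ("Attention Is All You Need", Vaswani et al., 2017) that the four steps correspond to, first in the paper's row-major convention and then rewritten for the "CT" layout used in the accepted answer's code. In the second form, W_Q, W_K, W_V, W_O and b_Q, b_K, b_V, b_O stand for the layer's QueryWeights, KeyWeights, ValueWeights, OutputWeights and the matching bias properties, d_k is the number of key channels per head, softmax_col is a softmax over each column, and [ ; ] is vertical concatenation (cat(1,...)). The paper writes the projections without biases; including them follows the layer's bias properties.

```latex
\begin{aligned}
&\text{Paper (rows are time steps):}\\
&\mathrm{Attention}(Q,K,V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V,\\
&\mathrm{head}_i = \mathrm{Attention}(XW_i^{Q},\, XW_i^{K},\, XW_i^{V}),\quad
\mathrm{MultiHead}(X) = \mathrm{Concat}(\mathrm{head}_1,\dots,\mathrm{head}_h)\,W^{O}\\[6pt]
&\text{CT layout (columns are time steps):}\\
&q = W_Q X + b_Q,\quad k = W_K X + b_K,\quad v = W_V X + b_V,\\
&\mathrm{head}_i = v_i\,\mathrm{softmax}_{\mathrm{col}}\!\left(\frac{k_i^{\top} q_i}{\sqrt{d_k}}\right),\quad
Y = W_O\,[\mathrm{head}_1;\dots;\mathrm{head}_h] + b_O
\end{aligned}
```

Transposing the paper's head_i with Q_i = q_i^T, K_i = k_i^T, V_i = v_i^T gives exactly the CT expression, which is why the code can compute k'*q with a softmax over dim 1.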

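The sqrt(dk) divisor in step 2 (the "scaled" in scaled dot-product attention) keeps the scores from growing with the head size. The sketch below checks this empirically under an idealized assumption that query and key entries are i.i.d. standard normal (a trained layer will not produce exactly this): the variance of q'*k comes out close to dk, and dividing by sqrt(dk) brings it back to about 1.

```matlab
% Empirical variance of dot products, assuming i.i.d. N(0,1) query/key entries.
numSamples = 5000;
for dk = [8 64 512]
    Q = randn(dk, numSamples);
    K = randn(dk, numSamples);
    s = sum(Q .* K, 1);                 % numSamples independent dot products q'*k
    fprintf('dk = %3d: var(q''k) = %7.1f, var(q''k/sqrt(dk)) = %.2f\n', ...
        dk, var(s), var(s / sqrt(dk)));
end
```

Without the scaling, larger heads push the softmax toward a near one-hot output, which is the motivation the paper gives for the divisor.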
Accepted Answer

cui,xingxing on 11 Jan 2024
Edited: cui,xingxing on 27 Apr 2024
Here is a simple code workflow of my own, using 2-D input in "CT" (channel × time) format, to illustrate how each step works.
The comment after each variable gives its dimensions.
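To make those dimension comments concrete, the arithmetic below works out the sizes for the configuration the script uses (C = 10 input channels, T = 20 time steps, 6 heads, 48 key channels, 12 value channels, output size 15). The names diQ and diV mirror the script's comments; the divisibility check reflects what the script's integer slicing per head assumes.

```matlab
% Sizes for the example configuration (C = channels, T = time steps).
C = 10; T = 20;                 % XTrain = rand(10,20), format "CT"
numHeads = 6;
N1 = 48;                        % queryDims: rows of QueryWeights and KeyWeights
N2 = 12;                        % NumValueChannels: rows of ValueWeights
N3 = 15;                        % OutputSize: rows of OutputWeights

% The per-head slicing needs the channel counts to divide evenly by numHeads.
assert(mod(N1, numHeads) == 0 && mod(N2, numHeads) == 0)
diQ = N1 / numHeads;            % query/key channels per head
diV = N2 / numHeads;            % value channels per head

fprintf('QueryWeights %dx%d, ValueWeights %dx%d, OutputWeights %dx%d\n', N1, C, N2, C, N3, N2)
fprintf('per head: qi, ki %dx%d, vi %dx%d, attentionScores %dx%d\n', diQ, T, diV, T, T, T)
fprintf('merged attention %dx%d, output %dx%d\n', N2, T, N3, T)
```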
%% Verify that selfAttentionLayer's computation matches a manual calculation!
XTrain = dlarray(rand(10,20));% CT
numClasses = 4;
numHeads = 6;
queryDims = 48; % N1=48
layers = [inputLayer(size(XTrain),"CT");
selfAttentionLayer(numHeads,queryDims,NumValueChannels=12,OutputSize=15,Name="sa");
layerNormalizationLayer;
fullyConnectedLayer(numClasses);
softmaxLayer];
net = dlnetwork(layers);
% analyzeNetwork(net)
XTrain = dlarray(XTrain,"CT");
[act1,act2] = predict(net,XTrain,Outputs=["input","sa"]);
act1 = extractdata(act1);% CT
act2 = extractdata(act2);% CT
% layer params
layerSA = net.Layers(2);
QWeights = layerSA.QueryWeights; % N1*C
KWeights = layerSA.KeyWeights;% N1*C
VWeights = layerSA.ValueWeights;% N2*C
outputW = layerSA.OutputWeights;% N3*N2
Qbias = layerSA.QueryBias; % N1*1
Kbias = layerSA.KeyBias;% N1*1
Vbias = layerSA.ValueBias; % N2*1
outputB = layerSA.OutputBias;% N3*1
% step1
q = QWeights*act1+Qbias; % N1*T
k = KWeights*act1+Kbias;% N1*T
v = VWeights*act1+Vbias;% N2*T
% step2,multiple heads
numChannelsQPerHeads = size(q,1)/numHeads;% 1*1
numChannelsVPerHeads = size(v,1)/numHeads;% 1*1
attentionM = cell(1,numHeads);
for i = 1:numHeads
    idxQRange = numChannelsQPerHeads*(i-1)+1:numChannelsQPerHeads*i;
    idxVRange = numChannelsVPerHeads*(i-1)+1:numChannelsVPerHeads*i;
    qi = q(idxQRange,:);% diQ*T
    ki = k(idxQRange,:);% diQ*T
    vi = v(idxVRange,:);% diV*T
    % attention
    dk = size(qi,1);% 1*1
    attentionScores = mysoftmax(ki'*qi./sqrt(dk)); % T*T (keys x queries); MATLAB's internal code uses k'*q, not q'*k, with softmax over dim 1
    attentionM{i} = vi*attentionScores; % diV*T
end
%step3,merge attentionM
attention = cat(1,attentionM{:}); % N2*T,N2 = diV*numHeads
%step4,output linear projection
act_ = outputW*attention+outputB;% N3*T
act2(1,:)
ans = 1×20
-0.5919 -0.5888 -0.5905 -0.5916 -0.5902 -0.5956 -0.5936 -0.5910 -0.5906 -0.5922 -0.5943 -0.5926 -0.5915 -0.5920 -0.5947 -0.5935 -0.5925 -0.5932 -0.5917 -0.5884
act_(1,:)
ans = 1×20
-0.5919 -0.5888 -0.5905 -0.5916 -0.5902 -0.5956 -0.5936 -0.5910 -0.5906 -0.5922 -0.5943 -0.5926 -0.5915 -0.5920 -0.5947 -0.5935 -0.5925 -0.5932 -0.5917 -0.5884
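As an additional, toolbox-free check, the sketch below rebuilds the same four steps in plain MATLAB with random placeholder weights (not the trained layer's) and compares two things: the column-major form v_i*softmax(k_i'*q_i) against the paper's row-major softmax(Q_i*K_i')*V_i, and "concatenate heads, then multiply by the output weights" against summing each head through its own block of columns of the output weights. All variable names here (WQ, bQ, softmaxCols, softmaxRows, ...) are illustrative assumptions.

```matlab
% Toolbox-free multi-head self-attention in "CT" layout (columns are time steps).
% Weights are random placeholders; sizes mirror the example above.
C = 10; T = 20; numHeads = 6; N1 = 48; N2 = 12; N3 = 15;
X  = rand(C, T);
WQ = randn(N1, C);  bQ = randn(N1, 1);
WK = randn(N1, C);  bK = randn(N1, 1);
WV = randn(N2, C);  bV = randn(N2, 1);
WO = randn(N3, N2); bO = randn(N3, 1);

% Numerically stable softmax over columns (dim 1) and over rows (dim 2).
softmaxCols = @(S) exp(S - max(S, [], 1)) ./ sum(exp(S - max(S, [], 1)), 1);
softmaxRows = @(S) exp(S - max(S, [], 2)) ./ sum(exp(S - max(S, [], 2)), 2);

% Step 1: projections; implicit expansion adds each bias to every column.
q = WQ*X + bQ;  k = WK*X + bK;  v = WV*X + bV;
dQ = N1/numHeads;  dV = N2/numHeads;

% Step 2: per-head attention in both layouts.
headsCT  = cell(numHeads, 1);       % dV x T, as in the answer's loop
headsRow = cell(numHeads, 1);       % paper's T x dV result, transposed back
for i = 1:numHeads
    iq = (i-1)*dQ+1 : i*dQ;
    iv = (i-1)*dV+1 : i*dV;
    qi = q(iq, :);  ki = k(iq, :);  vi = v(iv, :);
    headsCT{i} = vi * softmaxCols(ki'*qi / sqrt(dQ));
    Qr = qi';  Kr = ki';  Vr = vi';                 % rows are time steps
    headsRow{i} = (softmaxRows(Qr*Kr' / sqrt(dQ)) * Vr)';
end

% Steps 3-4: merge heads, then output projection (N3 x T).
Y = WO*cat(1, headsCT{:}) + bO;

% Same result as summing per-head contributions through column blocks of WO.
Ysum = repmat(bO, 1, T);
for i = 1:numHeads
    Ysum = Ysum + WO(:, (i-1)*dV+1 : i*dV) * headsCT{i};
end

d1 = cat(1, headsCT{:}) - cat(1, headsRow{:});
d2 = Y - Ysum;
errLayout = max(abs(d1(:)));
errMerge  = max(abs(d2(:)));
fprintf('layout mismatch: %.2e, merge mismatch: %.2e\n', errLayout, errMerge)
```

Both mismatches should be at floating-point round-off level, which supports reading the layer's k'*q ordering as the transpose of the paper's formula rather than a different computation.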
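I have reproduced its working process in the simplest possible way: the first row of the layer's output (act2) matches my manual result (act_) to the displayed precision. I hope it helps others.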
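function out = mysoftmax(X,dim)
% Softmax along dimension dim (default 1, i.e. over each column).
arguments
    X
    dim = 1;
end
X = X-max(X,[],dim); % subtract the max so exp(X) does not overflow to Inf when X is large
X = exp(X);
out = X./sum(X,dim);
end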
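The max-subtraction line in mysoftmax guards against exp overflowing to Inf; it does not change the mathematical result, because softmax is unchanged when the same constant is subtracted from every entry of a column. The sketch below, using hypothetical anonymous functions naiveSoftmax and stableSoftmax, shows what happens without it on large scores.

```matlab
% Column-wise softmax without and with the max shift.
naiveSoftmax  = @(X) exp(X) ./ sum(exp(X), 1);
stableSoftmax = @(X) exp(X - max(X, [], 1)) ./ sum(exp(X - max(X, [], 1)), 1);

S = [1000; 1001; 1002];          % exp(1000) exceeds the double range, giving Inf
pNaive  = naiveSoftmax(S);       % Inf./Inf produces NaN in every entry
pStable = stableSoftmax(S);      % identical to softmax of [0; 1; 2]
disp(pNaive.'); disp(pStable.');

% On moderate inputs the two versions agree up to round-off.
small = [0.1; -0.3; 0.7];
fprintf('difference on small inputs: %.2e\n', ...
    max(abs(naiveSoftmax(small) - stableSoftmax(small))))
```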
-------------------------Off-topic interlude, 2024-------------------------------
I am currently looking for a job in the field of CV algorithm development, based in Shenzhen, Guangdong, China, or a remote support position. I would be very grateful if anyone is willing to offer me a job or make a recommendation. My preliminary resume can be found at: https://cuixing158.github.io/about/ . Thank you!
Email: cuixingxing150@gmail.com



Release: R2023b

