Using a self-attention layer with 2D images
Hi,
I am wondering how to use the selfAttentionLayer for image classification with a CNN without needing to flatten the data, as is done in this example:
% load digit dataset
digitDatasetPath = fullfile(matlabroot, 'toolbox', 'nnet', 'nndemos', 'nndatasets', 'DigitDataset');
imds = imageDatastore(digitDatasetPath, ...
    'IncludeSubfolders', true, 'LabelSource', 'foldernames');
[imdsTrain, imdsValidation] = splitEachLabel(imds, 0.7, 'randomized');
% define network architecture
layers = [
    imageInputLayer([28 28 1], 'Name', 'input')
    convolution2dLayer(3, 32, 'Padding', 'same', 'Name', 'conv1')
    batchNormalizationLayer('Name', 'bn1')
    reluLayer('Name', 'relu1')
    maxPooling2dLayer(2, 'Stride', 2, 'Name', 'maxpool1')
    convolution2dLayer(3, 64, 'Padding', 'same', 'Name', 'conv2')
    batchNormalizationLayer('Name', 'bn2')
    reluLayer('Name', 'relu2')
    maxPooling2dLayer(2, 'Stride', 2, 'Name', 'maxpool2')
    flattenLayer('Name', 'flatten')
    selfAttentionLayer(8, 64, 'Name', 'self_attention')
    fullyConnectedLayer(10, 'Name', 'fc')
    softmaxLayer('Name', 'softmax')
    classificationLayer('Name', 'output')];
% set training options
options = trainingOptions('sgdm', ...
    'InitialLearnRate', 0.01, ...
    'MaxEpochs', 5, ...
    'Shuffle', 'every-epoch', ...
    'ValidationData', imdsValidation, ...
    'ValidationFrequency', 30, ...
    'Verbose', false, ...
    'Plots', 'training-progress');
% training the network
net = trainNetwork(imdsTrain, layers, options);
Answers (1)
Neha on 21 Nov 2023
        Hi Mahmoud,
I understand that you want to use a self-attention layer in image classification. The self-attention layer, also known as the multi-head self-attention layer, is commonly employed in Transformer models such as BERT and vision transformers (ViT). Its primary function is to model the relationships between positions within the input data. This input data is usually sequential, representing either temporal sequences or 1D spatial information. Therefore, it is necessary to use the "flattenLayer" to ensure that the input to the "selfAttentionLayer" is one-dimensional.
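If the goal is instead for attention to relate spatial positions to one another, as in a vision transformer, one option is to keep each spatial position as its own token rather than collapsing the whole feature map into one vector. Below is a minimal, untested sketch assuming R2023b or later ("selfAttentionLayer" needs R2023a+, "trainnet" needs R2023b+); the "to_sequence" functionLayer is a custom helper, not a built-in, and it relabels the 7-by-7-by-64 conv output as a 49-token sequence in "CBT" format:
% Untested sketch: relabel spatial positions as a sequence of tokens so
% self-attention attends across positions instead of one flat vector.
% to_sequence is a custom helper layer, not a built-in.
toSequence = functionLayer(@(X) dlarray( ...
    permute(reshape(stripdims(X), [], size(X,3), size(X,4)), [2 3 1]), ...
    'CBT'), 'Formattable', true, 'Name', 'to_sequence');
layers = [
    imageInputLayer([28 28 1], 'Name', 'input')
    convolution2dLayer(3, 32, 'Padding', 'same', 'Name', 'conv1')
    batchNormalizationLayer('Name', 'bn1')
    reluLayer('Name', 'relu1')
    maxPooling2dLayer(2, 'Stride', 2, 'Name', 'maxpool1')
    convolution2dLayer(3, 64, 'Padding', 'same', 'Name', 'conv2')
    batchNormalizationLayer('Name', 'bn2')
    reluLayer('Name', 'relu2')
    maxPooling2dLayer(2, 'Stride', 2, 'Name', 'maxpool2')
    toSequence                                  % 7x7x64 -> 49 tokens of 64 channels
    selfAttentionLayer(8, 64, 'Name', 'self_attention')
    globalAveragePooling1dLayer('Name', 'pool') % average over the 49 tokens
    fullyConnectedLayer(10, 'Name', 'fc')
    softmaxLayer('Name', 'softmax')];
% trainnet takes the loss function directly, so no classificationLayer;
% reuses imdsTrain and the training options defined above
net = trainnet(imdsTrain, layers, "crossentropy", options);
This way the attention layer sees 49 distinct tokens (one per spatial position) instead of a single 3136-element vector, which is closer to how ViT-style models apply attention to images.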
Hope this helps!