I'm using a Vision Transformer (ViT) in my code. How can I convert the 1-D output of a ViT layer into a 2-D feature map with format SSCB?
Hi Abdulrahman,
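I can't execute this code myself because visionTransformer requires Computer Vision Toolbox, so please treat the snippet below as untested. Starting from the MathWorks code you referenced, I adjusted the input and reshaping steps so that the 1-D token output of the encoder becomes a 24 x 24 x 768 feature map in SSCB format. Here is the updated code, step by step:
% Get the pretrained Vision Transformer model (requires Computer Vision Toolbox)
net = visionTransformer;
% Create a dummy input that matches the network's input size:
% 384 x 384 x 3 for the default model (check net.Layers(1).InputSize otherwise)
X = dlarray(single(rand(384, 384, 3)), 'SSCB');
% Obtain the output embedding from the last LayerNormalizationLayer
out = forward(net, X, 'Outputs', 'encoder_norm');
% Reshape the output patch embedding into a 2-D feature map
out = reshapePatchEmbedding(out);
% Local functions must appear at the end of a script file
function out = reshapePatchEmbedding(in)
% Drop the dimension labels so that plain indexing and reshape apply
out = stripdims(in);
% Remove the output embedding corresponding to the class token (first token)
out = out(2:end,:,:);
% Reshape the remaining patch tokens into a square spatial grid
WH = sqrt(size(out, 1));   % 576 patch tokens -> 24 x 24 grid
assert(WH == round(WH), 'Number of patch tokens is not a perfect square.')
C = size(out, 2);          % embedding dimension (768)
out = reshape(out, WH, WH, C, []); % Shape is W x H x C x N
out = permute(out, [2, 1, 3, 4]);  % Shape is H x W x C x N
% Convert to a formatted dlarray
out = dlarray(out, 'SSCB');
end
In this version, the dummy input matches the network's 384 x 384 x 3 input size rather than the desired output size. The encoder output has one token per row: 577 tokens (1 class token plus 24 x 24 = 576 patch tokens), each with 768 channels. reshapePatchEmbedding drops the class token and rearranges the 576 patch tokens into a 24 x 24 x 768 x N array labeled 'SSCB'. Please let me know if this helps resolve your issue.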
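Because the full pipeline needs Computer Vision Toolbox, the reshape logic on its own can be checked with plain MATLAB arrays. The sketch below is a hypothetical toolbox-free test, not part of the MathWorks example: it builds a dummy 577 x 768 token matrix whose rows hold their own patch index (0 for the class token) and applies the same drop/reshape/permute steps. It assumes, as the comments in reshapePatchEmbedding imply, that consecutive patch tokens run along the image width; under that assumption patches 1, 2, 3 should land along the first row of the result.

```matlab
% Toolbox-free check of the token-to-grid reshape.
% Assumptions: 24 x 24 patch grid (384 / 16) and 768-dim embeddings (ViT-Base).
numPatchesPerSide = 24;
embedDim = 768;
numTokens = numPatchesPerSide^2 + 1;   % 576 patch tokens + 1 class token

% Row k holds the value k-1: row 1 is the class token (0),
% rows 2..577 are patches 1..576.
tokens = repmat((0:numTokens-1)', 1, embedDim);   % numTokens x embedDim

patchTokens = tokens(2:end, :);                   % drop the class token
featureMap = reshape(patchTokens, numPatchesPerSide, numPatchesPerSide, embedDim); % W x H x C
featureMap = permute(featureMap, [2 1 3]);        % H x W x C, as in reshapePatchEmbedding

disp(size(featureMap))          % 24 24 768
disp(featureMap(1, 1:3, 1))     % patches 1 2 3 along the first row
disp(featureMap(2, 1, 1))       % patch 25 starts the second row
```

If the printed indices do not match the patch order your network uses, swap or drop the permute step accordingly.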