attention
Syntax
Y = attention(queries,keys,values,numHeads)
[Y,weights] = attention(queries,keys,values,numHeads)
[Y,weights] = attention(queries,keys,values,numHeads,Name=Value)
Description
The attention operation focuses on parts of the input using weighted multiplication operations.
Examples
Specify the sizes of the queries, keys, and values.
querySize = 100; valueSize = 120; numQueries = 64; numValues = 80; numObservations = 32;
Create random arrays containing the queries, keys, and values. For the queries, specify the dlarray format "CBT" (channel, batch, time).
queries = dlarray(rand(querySize,numObservations, numQueries),"CBT");
keys = dlarray(rand(querySize,numObservations, numValues));
values = dlarray(rand(valueSize,numObservations, numValues));
Specify the number of attention heads.
numHeads = 5;
Apply the attention operation.
[Y,weights] = attention(queries,keys,values,numHeads);
View the sizes and format of the output.
size(Y)
ans = 1×3
120 32 64
dims(Y)
ans = 'CBT'
View the sizes and format of the weights.
size(weights)
ans = 1×4
80 64 5 32
dims(weights)
ans = 0×0 empty char array
You can use the attention function to implement the multihead self attention operation [1] that focuses on parts of the input.
Create the multiheadSelfAttention function, listed in the Multihead Self Attention Function section of the example. The multiheadSelfAttention function takes as input the data X, the number of heads, and the learnable weights for the queries, keys, values, and output data, and returns the multihead attention values.
The X input must be an unformatted dlarray object, where the first dimension corresponds to the input channels, the second dimension corresponds to the batch dimension, and the third dimension corresponds to the time or spatial dimension.
Create an array of sequence data.
numChannels = 10; numObservations = 128; numTimeSteps = 100; X = rand(numChannels,numObservations,numTimeSteps); X = dlarray(X); size(X)
ans = 1×3
10 128 100
Specify the number of heads for multihead attention.
numHeads = 8;
Initialize the learnable parameters for multihead attention.
The learnable query, key, and value weights must be (numChannels*numHeads)-by-numChannels arrays. The learnable output weights must be a (numChannels*numHeads)-by-(numChannels*numHeads) array.
outputSize = numChannels*numHeads; WQ = rand(outputSize,numChannels); WK = rand(outputSize,numChannels); WV = rand(outputSize,numChannels); WO = rand(outputSize,outputSize);
Apply the multihead self attention operation.
Y = multiheadSelfAttention(X,numHeads,WQ,WK,WV,WO);
View the size of the output. The output has size (numChannels*numHeads)-by-numObservations-by-numTimeSteps.
size(Y)
ans = 1×3
80 128 100
Multihead Self Attention Function
The multiheadSelfAttention function takes as input the data X, the number of heads, and the learnable weights for the queries, keys, values, and output data, and returns the multihead attention values.
The X input must be an unformatted dlarray object, where the first dimension corresponds to the input channels, the second dimension corresponds to the batch dimension, and the third dimension corresponds to the time or spatial dimension. The learnable query, key, and value weight matrices are (numChannels*numHeads)-by-numChannels matrices. The learnable output weights matrix is a (numChannels*numHeads)-by-(numChannels*numHeads) matrix.
function Y = multiheadSelfAttention(X,numHeads,WQ,WK,WV,WO)
queries = pagemtimes(WQ,X);
keys = pagemtimes(WK,X);
values = pagemtimes(WV,X);
A = attention(queries,keys,values,numHeads,DataFormat="CBT");
Y = pagemtimes(WO,A);
end
You can use the attention function to create a function that applies the Luong attention operation to its input. Create the luongAttention function, listed at the end of the example, that applies the Luong attention operation.
Specify the array sizes.
numHiddenUnits = 100; latentSize = 16;
Create random arrays containing the input data.
hiddenState = dlarray(rand(numHiddenUnits,1)); Z = dlarray(rand(latentSize,1)); weights = dlarray(rand(numHiddenUnits,latentSize));
Apply the luongAttention function.
[context,scores] = luongAttention(hiddenState,Z,weights);
View the sizes of the outputs.
size(context)
ans = 1×2
16 1
size(scores)
ans = 1×2
1 1
Luong Attention Function
The luongAttention function returns the context vector and attention scores according to the Luong "general" scoring [2]. This operation is equivalent to dot-product attention with queries, keys, and values specified as the hidden state, the weighted latent representation, and the latent representation, respectively.
function [context,scores] = luongAttention(hiddenState,Z,weights)
numHeads = 1;
queries = hiddenState;
keys = pagemtimes(weights,Z);
values = Z;
[context,scores] = attention(queries,keys,values,numHeads, ...
    Scale=1, ...
    DataFormat="CBT");
end
Input Arguments
Queries, specified as a dlarray object.
queries can have at most one "S" (spatial)
or "T" (time) dimension. Any dimensions in
queries labeled "U" (unspecified) must be
singleton. If queries is an unformatted dlarray
object, then specify the data format using the DataFormat
option.
The size of the "C" (channel) dimension in keys must
match the size of the corresponding dimension in queries.
The size of the "B" (batch) dimension in queries, keys, and values must match.
Keys, specified as a dlarray object or a numeric array.
If keys is a formatted dlarray object, then
its format must match the format of queries. If
keys is not a formatted dlarray object, then the
function uses the same format as queries.
The size of any "S" (spatial) or "T" (time) dimensions in keys must match the size of the corresponding dimension in values.
The size of the "C" (channel) dimension in keys must
match the size of the corresponding dimension in queries.
The size of the "B" (batch) dimension in queries, keys, and values must match.
Values, specified as a dlarray object or a numeric array.
If values is a formatted dlarray object, then
its format must match the format of queries. Otherwise, the
function uses the same format as queries.
The size of any "S" (spatial) or "T" (time) dimensions in keys must match the size of the corresponding dimension in values.
The size of the "B" (batch) dimension in queries, keys, and values must match.
Number of attention heads, specified as a positive integer.
Each head performs a separate linear transformation of the input and computes attention weights independently. The layer uses these attention weights to compute a weighted sum of the input representations, generating a context vector. Increasing the number of heads lets the model capture different types of dependencies and attend to different parts of the input simultaneously. Reducing the number of heads can lower the computational cost of the layer.
The value of numHeads must evenly divide the size of the
"C" (channel) dimension of queries,
keys, and values.
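For example, this sketch (with illustrative sizes) splits a 10-channel query and key array and a 12-channel value array across two heads; both channel counts are divisible by numHeads = 2:

```matlab
% Channel ("C") sizes must be divisible by the number of heads (2).
queries = dlarray(rand(10,16,20),"CBT");   % 10 channels, 16 observations, 20 time steps
keys    = dlarray(rand(10,16,25),"CBT");   % 25 key-value time steps
values  = dlarray(rand(12,16,25),"CBT");   % value channels need not match the queries
Y = attention(queries,keys,values,2);      % 12-by-16-by-20 formatted output
```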
Name-Value Arguments
Specify optional pairs of arguments as
Name1=Value1,...,NameN=ValueN, where Name is
the argument name and Value is the corresponding value.
Name-value arguments must appear after other arguments, but the order of the
pairs does not matter.
Before R2021a, use commas to separate each name and value, and enclose
Name in quotes.
Example: attention(queries,keys,values,numHeads,DataFormat="CBT")
applies the attention operation for unformatted data and specifies the data format
"CBT" (channel, batch, time).
Description of the data dimensions, specified as a character vector or string scalar.
A data format is a string of characters, where each character describes the type of the corresponding data dimension.
The characters are:
"S" — Spatial
"C" — Channel
"B" — Batch
"T" — Time
"U" — Unspecified
For example, consider an array that represents a batch of sequences where the first,
second, and third dimensions correspond to channels, observations, and time steps,
respectively. You can describe the data as having the format "CBT"
(channel, batch, time).
You can specify multiple dimensions labeled "S" or "U".
You can use the labels "C", "B", and
"T" once each, at most. The software ignores singleton trailing
"U" dimensions after the second dimension.
If the input data is not a formatted dlarray object, then you must
specify the DataFormat option.
For more information, see Deep Learning Data Formats.
Data Types: char | string
Since R2026a
Number of query groups (equivalent to the number of key-value heads), specified as one of these values:
"num-heads" — Use the numHeads argument value.
Positive integer — Use the specified number of query groups. This value must divide the numHeads argument value.
The value of NumQueryGroups specifies the type of attention operation:
For multihead attention, set NumQueryGroups to numHeads.
For multiquery attention (MQA), set NumQueryGroups to 1.
For grouped-query attention (GQA), set NumQueryGroups to a positive integer between 1 and numHeads.
Using multiquery attention and grouped-query attention can reduce memory and computation time for large inputs.
When the number of query groups is greater than 1, the operation creates groups of query channels-per-head, and applies the attention operation within each group.
For example, for six heads with three query groups, the operation splits the query channels into the heads (h1, …, h6) and then creates the groups of heads g1=(h1,h2), g2=(h3,h4), and g3=(h5,h6). The operation also splits the key and value channels into the heads g1, g2, and g3.
When the number of query groups matches the number of heads, each group has one head and the operation is equivalent to multihead attention. When the number of query groups is 1, all the heads are in the same group and the operation is equivalent to multiquery attention.
Data Types: single | double | int8 | int16 | int32 | int64 | uint8 | uint16 | uint32 | uint64 | char | string
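As a sketch (since R2026a; sizes are illustrative, and the keys and values are assumed to have the same channel count as the queries, per the input requirements above), with eight heads you could request multiquery or grouped-query attention like this:

```matlab
% 64 channels divide evenly across 8 heads.
queries = dlarray(rand(64,8,20),"CBT");
keys    = dlarray(rand(64,8,30),"CBT");
values  = dlarray(rand(64,8,30),"CBT");
numHeads = 8;
Ymqa = attention(queries,keys,values,numHeads,NumQueryGroups=1);  % multiquery: one shared group
Ygqa = attention(queries,keys,values,numHeads,NumQueryGroups=4);  % grouped-query: 4 divides 8
```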
Multiplicative factor for scaled dot-product attention [1], specified as one of these values:
"auto" — Multiply the dot-product by 1/sqrt(dk), where dk denotes the number of channels in the keys divided by the number of heads.
Numeric scalar — Multiply the dot-product by the specified scale factor.
Data Types: single | double | char | string
Mask indicating which elements of the input correspond to padding values,
specified as a dlarray object, a logical array, or a binary-valued
numeric array.
The function allows attention to an element of the input key-value pairs when the corresponding element in PaddingMask is 1 and prevents attention when the element is 0.
The padding mask can be formatted or unformatted:
If PaddingMask is a formatted dlarray object, then the dimension labels of PaddingMask must match the dimension labels of keys, ignoring any "C" (channel) and "U" (unspecified) dimensions (since R2026a). Before R2026a: If PaddingMask is a formatted dlarray object, then its format must match that of the keys.
If PaddingMask is not a formatted dlarray object, then it must have the same number of nonchannel dimensions as the keys (since R2026a). In this case, the function uses the same format as the keys, ignoring any missing "C" dimensions. Before R2026a: If PaddingMask is not a formatted dlarray object, then it must have the same number of dimensions as the keys. In this case, the function uses the same format as the keys.
The padding mask can have different layouts:
The size of the "S" (spatial), "T" (time), and "B" (batch) dimensions in PaddingMask must match the size of the corresponding dimensions in keys and values.
The padding mask can have any number of "U" (unspecified) dimensions. The software uses the values in the first "U" (unspecified) dimension.
The "U" dimensions of the padding mask can be nonsingleton (since R2026a). In this case, the software uses the values in the first index to indicate padding values. Before R2026a: The "U" dimensions of the padding mask must be singleton.
The padding mask can have any number of channels. The software uses the values in the first channel only to indicate padding values.
The default value is a logical array of ones with the same size as
keys.
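For example, this sketch (sizes are illustrative; the unformatted mask has the same size as the keys) excludes the padded tail of one sequence in a batch:

```matlab
queries = dlarray(rand(8,2,5),"CBT");   % 8 channels, 2 observations, 5 time steps
keys    = dlarray(rand(8,2,5),"CBT");
values  = dlarray(rand(8,2,5),"CBT");
mask = true(8,2,5);                     % unformatted logical mask, same size as keys
mask(:,2,4:5) = false;                  % time steps 4-5 of observation 2 are padding
Y = attention(queries,keys,values,2,PaddingMask=mask);
```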
Attention mask indicating which elements to include when applying the attention operation, specified as one of these values:
"none" — Do not prevent attention to elements with respect to their positions. If AttentionMask is "none", then the software prevents attention using only the padding mask.
"causal" — Prevent elements in position m in the "S" (spatial) or "T" (time) dimension of the input queries from providing attention to the elements in positions n, where n is greater than m in the corresponding dimension of the input keys and values. Use this option for autoregressive models.
Logical or numeric array — Prevent attention to elements of the input keys and values when the corresponding element in the specified array is 0. The specified array must be an Nk-by-Nq matrix or an Nk-by-Nq-by-numObservations array, where Nk is the size of the "S" (spatial) or "T" (time) dimension of the input keys, Nq is the size of the corresponding dimension of the input queries, and numObservations is the size of the "B" (batch) dimension in the input queries.
Data Types: single | double | int8 | int16 | int32 | int64 | uint8 | uint16 | uint32 | uint64 | logical | char | string
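For example, a causal self-attention sketch (illustrative sizes) for autoregressive models:

```matlab
% Each time step of X attends only to itself and earlier time steps.
X = dlarray(rand(16,4,10),"CBT");       % 16 channels, 4 observations, 10 time steps
Y = attention(X,X,X,4,AttentionMask="causal");
```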
Dropout probability for the attention weights, specified as a scalar in the range [0, 1).
Data Types: single | double
Output Arguments
Result of attention operation, returned as a dlarray object.
If queries is a formatted dlarray object, then
Y is a formatted dlarray object with the same
dimension labels as queries. The size of the
"C" (channel) dimension of Y is the same as
the size of the corresponding dimension in values. The size of the
"S" (spatial) or "T" (time) dimension of
Y is the same size as the corresponding dimension in
queries.
If queries is not a formatted dlarray object,
then Y is an unformatted dlarray object.
Attention weights, returned as an unformatted dlarray
object.
weights is a
Nk-by-Nq-by-numHeads-by-numObservations
array, where Nk is the size of the
"S" (spatial) or "T" (time) dimension of
keys, Nq is the size of
the corresponding dimension in queries, and
numObservations is the size of the "B" (batch)
dimension in queries.
Algorithms
The attention operation focuses on parts of the input using weighted multiplication operations.
The single-head dot-product attention operation is given by

attention(Q,K,V) = V · dropout(softmax(mask(s·KᵀQ, M)), p)

where:
Q denotes the queries.
K denotes the keys.
V denotes the values.
s denotes the scaling factor.
M is a mask array of ones and zeros.
p is the dropout probability.
The mask operation includes or excludes the values of the matrix multiplication by setting the input to −∞ for zero-valued mask elements. The mask is the union of the padding and attention masks. The softmax function normalizes the value of the input data across the channel dimension such that it sums to one. The dropout operation sets elements to zero with probability p.
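With no mask and zero dropout, the single-head operation can be sketched for 2-D, single-observation data as follows (sizes are illustrative; softmax is the Deep Learning Toolbox function):

```matlab
% Single-head scaled dot-product attention, written out.
Q = rand(6,9);                          % 6 channels, 9 queries
K = rand(6,7);                          % 6 channels, 7 keys
V = rand(5,7);                          % 5 value channels, 7 values
s = 1/sqrt(size(K,1));                  % "auto" scale with one head: 1/sqrt(dk)
W = softmax(dlarray(s*(K'*Q)),DataFormat="CB");  % 7-by-9 weights, columns sum to 1
Y = V*W;                                % 5-by-9 weighted sum of values
```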
The multihead attention operation applies the attention operation across multiple heads. Each head uses its own learnable query, key, and value projection matrices.
The multihead attention operation for the queries Q, keys K, and values V is given by

multiheadAttention(Q,K,V) = WO · concat(head1, …, headh)

where:
h is the number of heads.
WO is a learnable projection matrix for the output.
concat denotes concatenation of the heads along the channel dimension.
For the multihead attention operation, each learnable projection matrix for the queries, keys, and values is composed of concatenated matrices Wi, where i indexes over the heads.
The head operation is given by

headi = attention(WQi·Q, WKi·K, WVi·V)

where:
i indexes over the heads.
WQi is a learnable projection matrix for the queries.
WKi is a learnable projection matrix for the keys.
WVi is a learnable projection matrix for the values.
The multiquery attention (MQA) operation applies the attention operation across multiple heads. Each head uses its own learnable query projection matrix. The operation uses the same learnable key and value projection matrices across all heads.
The multiquery attention operation for the queries Q, keys K, and values V is given by

multiqueryAttention(Q,K,V) = WO · concat(head1, …, headh)

where:
h is the number of heads.
WO is a learnable projection matrix for the output.
For the multiquery attention operation, only the learnable projection matrix for the queries is composed of concatenated matrices Wi, where i indexes over the heads.
For multiquery attention, the head operation is given by

headi = attention(WQi·Q, WK·K, WV·V)

where:
i indexes over the heads.
WQi is a learnable projection matrix for the queries.
WK is a learnable projection matrix for the keys, shared across all heads.
WV is a learnable projection matrix for the values, shared across all heads.
The grouped query attention (GQA) operation applies the attention operation across several heads. The operation partitions the heads into groups that use the same learnable query projection matrix. The operation uses the same learnable key and value projection matrices for each group of query heads.
The grouped-query attention operation for the queries Q, keys K, and values V is given by
where
h is the number of heads.
WO is a learnable projection matrix for the output.
For the grouped query attention operation:
The learnable projection matrix for the queries is composed of concatenated matrices Wi, where i indexes over the heads.
The learnable projection matrices for the keys and values are composed of concatenated matrices Wj, where j indexes over the groups.
When the number of query groups is greater than 1, the operation creates groups of query channels-per-head, and applies the attention operation within each group.
For example, for six heads with three query groups, the operation splits the query channels into the heads (h1, …, h6) and then creates the groups of heads g1=(h1,h2), g2=(h3,h4), and g3=(h5,h6). The operation also splits the key and value channels into the heads g1, g2, and g3.
When the number of query groups matches the number of heads, each group has one head and the operation is equivalent to multihead attention. When the number of query groups is 1, all the heads are in the same group and the operation is equivalent to multiquery attention.
The head operation is given by

headi = attention(WQi·Q, WKg(i)·K, WVg(i)·V)

where:
i indexes over the heads.
g(i) is the group number of head i.
WQi is a learnable projection matrix for the queries.
WKg(i) is a learnable projection matrix for the keys for group g(i).
WVg(i) is a learnable projection matrix for the values for group g(i).
The self-attention operation is equivalent to setting the queries, keys, and values to the input data:
selfAttention(X) = attention(X,X,X)
multiheadSelfAttention(X) = multiheadAttention(X,X,X)
multiquerySelfAttention(X) = multiqueryAttention(X,X,X)
groupedQuerySelfAttention(X) = groupedQueryAttention(X,X,X)
In each of these cases, the operation still uses separate learnable parameters for the keys, queries, and values. For example, to calculate multiheadSelfAttention(X), the operation uses the head operation given by

headi = attention(WQi·X, WKi·X, WVi·X)

where:
i indexes over the heads.
X is the input data.
WQi is a learnable projection matrix for the queries.
WKi is a learnable projection matrix for the keys.
WVi is a learnable projection matrix for the values.
Most deep learning networks and functions operate on different dimensions of the input data in different ways.
For example, an LSTM operation iterates over the time dimension of the input data, and a batch normalization operation normalizes over the batch dimension of the input data.
To provide input data with labeled dimensions or input data with additional layout information, you can use data formats.
A data format is a string of characters, where each character describes the type of the corresponding data dimension.
The characters are:
"S" — Spatial
"C" — Channel
"B" — Batch
"T" — Time
"U" — Unspecified
For example, consider an array that represents a batch of sequences where the first,
second, and third dimensions correspond to channels, observations, and time steps,
respectively. You can describe the data as having the format "CBT"
(channel, batch, time).
To create formatted input data, create a dlarray object and specify the format using the second argument.
To provide additional layout information with unformatted data, specify the format using the
DataFormat
argument.
For more information, see Deep Learning Data Formats.
References
[1] Vaswani, Ashish, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Łukasz Kaiser, and Illia Polosukhin. "Attention is all you need." Advances in neural information processing systems 30 (December 2017): 6000-6010. https://papers.nips.cc/paper_files/paper/2017/hash/3f5ee243547dee91fbd053c1c4a845aa-Abstract.html.
[2] Luong, Minh-Thang, Hieu Pham, and Christopher D. Manning. "Effective approaches to attention-based neural machine translation." arXiv preprint arXiv:1508.04025 (2015).
Extended Capabilities
The attention function
supports GPU array input with these usage notes and limitations:
When at least one of these input arguments is a gpuArray object or a dlarray object with underlying data of type gpuArray, this function runs on the GPU:
queries
keys
values
For more information, see Run MATLAB Functions on a GPU (Parallel Computing Toolbox).
Version History
Introduced in R2022b

R2026a: The value of NumQueryGroups specifies the type of attention operation:
For multihead attention, set NumQueryGroups to numHeads.
For multiquery attention (MQA), set NumQueryGroups to 1.
For grouped-query attention (GQA), set NumQueryGroups to a positive integer between 1 and numHeads.
Using multiquery attention and grouped-query attention can reduce memory and computation time for large inputs.
The padding mask specified by PaddingMask does
not require a channel dimension. For masks that do not specify a channel dimension, the
operation assumes a singleton channel dimension.
See Also
padsequences | dlarray | dlgradient | dlfeval | lstm | gru | embed