Self-attention and transformers
Herman Kamper
2024-02, CC BY-SA 4.0
Issues with RNNs
Attention recap
Self-attention
Positional encodings
Multi-head attention
Masking the future in self-attention
Cross-attention
Transformer
Issues with RNNs
[Figure: an RNN processing the sentence fragment "Mister Dursley of number . . . was . . ." one word at a time.]
Architectural
Even with changes to deal with long-range dependencies (e.g. LSTM),
more recent observations inevitably have a bigger influence on the
current hidden state than those that are far away.
Computational
• Future RNN states can’t be computed before past hidden states
have been computed.
• Computations over time steps are therefore not parallelisable.
• We just can’t get away from the “for loop” over time in the forward pass of an RNN.
• So we can’t take advantage of the full power of batching on GPUs, which works best when several independent computations can be performed at once.
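As a rough illustration (a sketch, not from the original note): the RNN forward pass needs a loop over time, while an attention-style comparison between all positions is a single matrix product that can be parallelised.

```python
import numpy as np

T, D = 5, 4                       # sequence length, feature dimensionality
X = np.random.randn(T, D)         # one input sequence (one vector per time step)
W, U = np.random.randn(D, D), np.random.randn(D, D)

# RNN: each hidden state depends on the previous one, so we must loop over time.
h = np.zeros(D)
for t in range(T):                # inherently sequential: cannot be parallelised
    h = np.tanh(W.T @ X[t] + U.T @ h)

# Attention-style computation: all positions are compared to all other positions
# in a single matrix product, which a GPU can compute in parallel.
scores = X @ X.T                  # (T, T) scores between every pair of positions
```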
Attention doesn’t have these problems
[Figure: an encoder-decoder model with attention translating "hy het my gegooi" to "he threw me".]
Idea: Remove recurrence and rely solely on attention
Intuition from the Google AI blog post:
Attention recap
One way to think of attention intuitively is as a soft lookup table:
[Figure: attention as a soft lookup table. A query is compared to the keys, and the output is a weighted combination of the corresponding values.]
Computational graph:
[Figure: computational graph of attention. The query q is scored against each key k1, . . . , kN, the scores pass through a softmax to give the weights, and the weighted sum of the values v1, . . . , vN gives the context vector c.]
Mathematically:
• Output of attention: Context vector.

$$\mathbf{c} = \sum_{n=1}^{N} \alpha(\mathbf{q}, \mathbf{k}_n)\,\mathbf{v}_n \in \mathbb{R}^D$$

• Attention weight:

$$\alpha(\mathbf{q}, \mathbf{k}_n) = \mathrm{softmax}_n\big(a(\mathbf{q}, \mathbf{k}_n)\big) = \frac{\exp\left\{a(\mathbf{q}, \mathbf{k}_n)\right\}}{\sum_{j=1}^{N} \exp\left\{a(\mathbf{q}, \mathbf{k}_j)\right\}} \in [0, 1]$$

• Attention score:

$$a(\mathbf{q}, \mathbf{k}_n) \in \mathbb{R}$$
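A minimal NumPy sketch of these three equations, using a plain dot product as the score function a(q, kn) (the note keeps the score function generic; the scaled dot product appears later):

```python
import numpy as np

def attention_lookup(q, K, V):
    """Context vector c: values weighted by how well their keys match the query."""
    a = K @ q                                # scores a(q, k_n), here a dot product
    alpha = np.exp(a) / np.exp(a).sum()      # attention weights (softmax over n)
    return alpha @ V                         # c = sum_n alpha(q, k_n) v_n

# Example with N = 3 key-value pairs and D = 4 dimensional vectors
q = np.random.randn(4)
K = np.random.randn(3, 4)
V = np.random.randn(3, 4)
c = attention_lookup(q, K, V)                # c has shape (4,)
```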
Self-attention
[Figure: self-attention over the input "i went to school to learn". The query q6 for position 6 is compared to the keys k1, . . . , k6, and the values v1, . . . , v6 are combined to produce the output y6.]
[Figure: the full self-attention computation for position 6 of "i went to school to learn".]

$$\mathbf{y}_i = \sum_{t=1}^{T} \alpha_{i,t}\,\mathbf{v}_t$$

$$\alpha_{i,t} = \frac{e^{a_{i,t}}}{\sum_{j=1}^{T} e^{a_{i,j}}}$$

$$a_{i,t} = \frac{\mathbf{q}_i^\top \mathbf{k}_t}{\sqrt{D_k}}$$

$$\mathbf{q}_t = \mathbf{W}_q^\top \mathbf{x}_t, \qquad \mathbf{k}_t = \mathbf{W}_k^\top \mathbf{x}_t, \qquad \mathbf{v}_t = \mathbf{W}_v^\top \mathbf{x}_t$$

Layer input: $\mathbf{x}_1, \mathbf{x}_2, \ldots, \mathbf{x}_T$
Layer output: $\mathbf{y}_1, \mathbf{y}_2, \ldots, \mathbf{y}_T$
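As a sanity check, a NumPy sketch that computes a single output yi exactly as in the equations above (the matrix form in the next section does all positions at once):

```python
import numpy as np

def self_attention_output(i, X, Wq, Wk, Wv):
    """Compute y_i for a single position i, directly following the equations above."""
    Dk = Wk.shape[1]
    q_i = Wq.T @ X[i]                        # q_i = Wq^T x_i
    K = X @ Wk                               # row t is k_t = Wk^T x_t
    V = X @ Wv                               # row t is v_t = Wv^T x_t
    a = K @ q_i / np.sqrt(Dk)                # scores a_{i,t}
    alpha = np.exp(a) / np.exp(a).sum()      # weights alpha_{i,t}
    return alpha @ V                         # y_i = sum_t alpha_{i,t} v_t

# Example: T = 6 inputs of dimensionality D = 8, with Dk = Dv = 8
X = np.random.randn(6, 8)
Wq, Wk, Wv = (np.random.randn(8, 8) for _ in range(3))
y6 = self_attention_output(5, X, Wq, Wk, Wv)  # output for the last position
```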
In matrix form
Each of the T queries needs to be compared to each of the T keys.
We can express this in a compact matrix form.
Stack all the queries, keys and values as rows in matrices:
$$\mathbf{Q} \in \mathbb{R}^{T \times D_k}, \qquad \mathbf{K} \in \mathbb{R}^{T \times D_k}, \qquad \mathbf{V} \in \mathbb{R}^{T \times D_v}$$
We can then write all the dot products and weighting in a short
condensed form:
$$\mathrm{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \mathrm{softmax}\!\left(\frac{\mathbf{Q}\mathbf{K}^\top}{\sqrt{D_k}}\right)\mathbf{V}$$
If we denote the output as $\mathbf{Y} = \mathrm{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V})$, then we end up with a result $\mathbf{Y} \in \mathbb{R}^{T \times D_v}$.
The above holds in general for attention. For self-attention specifically,
we would have
$$\mathbf{Q} = \mathbf{X}\mathbf{W}_q, \qquad \mathbf{K} = \mathbf{X}\mathbf{W}_k, \qquad \mathbf{V} = \mathbf{X}\mathbf{W}_v$$

where the design matrix is $\mathbf{X} \in \mathbb{R}^{T \times D}$, with $D$ the dimensionality of the input.

You can figure out the shapes for the $\mathbf{W}$'s, e.g. $\mathbf{W}_k \in \mathbb{R}^{D \times D_k}$.
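A direct NumPy sketch of this matrix form (the softmax is applied row-wise, i.e. over the keys for each query):

```python
import numpy as np

def softmax(A):
    """Row-wise softmax (subtracting the row maximum for numerical stability)."""
    expA = np.exp(A - A.max(axis=-1, keepdims=True))
    return expA / expA.sum(axis=-1, keepdims=True)

def attention(Q, K, V):
    """Attention(Q, K, V) = softmax(Q K^T / sqrt(Dk)) V."""
    Dk = K.shape[-1]
    return softmax(Q @ K.T / np.sqrt(Dk)) @ V

def self_attention(X, Wq, Wk, Wv):
    """Self-attention: queries, keys and values all come from the same input X."""
    return attention(X @ Wq, X @ Wk, X @ Wv)   # Y with shape (T, Dv)
```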
Self-attention: A new computational block
A new block or layer, like an RNN or a CNN.
Can use this in both encoder and decoder modules. E.g. for machine
translation:
Figure from (Vaswani et al., 2017).
Sometimes “transformer” is used to refer to the self-attention layers
themselves, but other times it is used to refer to this specific encoder-
decoder model (which we will unpack in the rest of this note).
Positional encodings intuition
Positional encodings
In contrast to RNNs, there isn't any order information in the inputs to self-attention.
We can add positional encodings to the inputs:
$$\mathbf{p}_t \in \mathbb{R}^D$$
There is a unique pt for every input position. E.g. p10 will always be
the same for all input sequences.
How do we incorporate them? The positional encodings can be
concatenated to the inputs:
$$\tilde{\mathbf{x}}_t = \left[\mathbf{x}_t\,;\,\mathbf{p}_t\right]$$
But it is more common to just add them:¹

$$\tilde{\mathbf{x}}_t = \mathbf{x}_t + \mathbf{p}_t$$
Where do the positional encodings pt come from?
Learned positional encodings
We can let the $\mathbf{p}_t$'s be learnable parameters. This means we are adding a learnable matrix $\mathbf{P} \in \mathbb{R}^{D \times T}$ to all input sequences.
Problem: What if we have inputs that are longer than T?

(But this approach is still often used in practice.)
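A minimal sketch of this option, stored with one row of a learnable matrix per position (the maximum length T_max and the initialisation are illustrative assumptions; in practice P is trained along with the rest of the model):

```python
import numpy as np

T_max, D = 512, 64                       # illustrative maximum length and dimensionality
P = 0.01 * np.random.randn(T_max, D)     # learnable positional encodings, one row per
                                         # position (updated by backprop like any weight)

def add_positions(X):
    """Add a learned positional encoding to each input vector (X has shape (T, D))."""
    T = X.shape[0]
    assert T <= T_max, "no encoding exists for positions beyond T_max"
    return X + P[:T]
```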
¹I like the idea of concatenation more than adding. But Benjamin van Niekerk pointed out to me that if you pass $\tilde{\mathbf{x}}_t$ through a single linear layer, then concatenation and addition are very similar: in both cases you end up with a new representation that is a weighted sum of the original input and the positional encoding (there are just additional weights specifically for the positional encoding when you concatenate).
Represent position using sinusoids
Let’s use a single sinusoid as our pt :
[Figure: a single sinusoid (encoding dimension d = 6) plotted over positions 0 to 60.]
In this case, we would have a unique positional feature value for inputs with lengths up to roughly T = 36, after which the feature value would repeat. This could be useful if relative position at this scale is more important than absolute position.
Let’s add a cosine to obtain pt :
[Figure: a sine (d = 6) and a cosine (d = 7) at the same frequency plotted over positions 0 to 60.]
Now we would have unique positional encodings for a longer range.
But the model could also just decide that relative position matters
more.
We used sinusoids at a single frequency, so we are limited in the types of relative relationships we can model. So let us add more sine and cosine functions at different frequencies:
[Figure: sines and cosines at two different frequencies (dimensions d = 6 to d = 9) plotted over positions 0 to 60.]
Formally (Vaswani et al., 2017):
$$\mathbf{p}_t = \begin{bmatrix} \sin\left(\frac{t}{\lambda_1}\right) \\ \cos\left(\frac{t}{\lambda_1}\right) \\ \sin\left(\frac{t}{\lambda_2}\right) \\ \cos\left(\frac{t}{\lambda_2}\right) \\ \vdots \\ \sin\left(\frac{t}{\lambda_{D/2}}\right) \\ \cos\left(\frac{t}{\lambda_{D/2}}\right) \end{bmatrix}$$

where

$$\lambda_m = 10\,000^{2m/D}$$
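A NumPy sketch that builds these encodings from the formula above, stacked with one row per position (this reads the notation as m = 1, . . . , D/2 and positions t = 1, 2, . . .; the exact indexing convention is an assumption):

```python
import numpy as np

def sinusoidal_encodings(T, D):
    """Return a (T, D) matrix whose t-th row is the positional encoding p_t."""
    t = np.arange(1, T + 1)[:, None]            # positions t = 1, ..., T
    m = np.arange(1, D // 2 + 1)[None, :]       # frequency index m = 1, ..., D/2
    wavelength = 10_000.0 ** (2 * m / D)        # lambda_m
    P = np.zeros((T, D))
    P[:, 0::2] = np.sin(t / wavelength)         # dimensions 0, 2, 4, ...: sines
    P[:, 1::2] = np.cos(t / wavelength)         # dimensions 1, 3, 5, ...: cosines
    return P
```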
If we stack all these into $\mathbf{P} \in \mathbb{R}^{D \times T}$:

[Figure: the full positional encoding matrix, with encoding dimension on the vertical axis and position on the horizontal axis.]
There are formal reasons why this encodes relative position (Denk, 2019).² But intuitively you should be able to see that the periodicity indicates that absolute position isn't necessarily important.

In practice, however, this approach does not enable extrapolation to sequences that are much longer than those seen during training (Hewitt, 2023).
(But it is still often used in practice. The original transformer paper
did this – look at the transformer diagram above.)
²For a fixed offset between two positional encodings, there is a linear transformation to take you from the one to the other. E.g. you can go from $\mathbf{p}_{10}$ to $\mathbf{p}_{15}$ using some linear transformation, and this will be the same transformation needed to go from $\mathbf{p}_{30}$ to $\mathbf{p}_{35}$.
The clock analogy for positional encodings
Think of each pair of dimensions of $\mathbf{p}_t$ as a clock rotating at a different frequency. The position of the clock hand is uniquely determined by the sine and cosine values for that frequency.³
We have D/2 clocks. For each position t, we will have a specific configuration of the clocks. This tells us where in the input sequence we are. This works even if we never saw an input sequence of this length during training (the clocks just keep moving).
But there is also periodicity in how the clocks change with different t. To move from the configuration p10 to p15, we need to change the clock faces in some way (this can be done through a linear transformation). But the way in which we change the clock faces would be the same as the transformation from p30 to p35.
So in short, the sinusoidal positional encodings can tell us where we
are in the input, even if that position was never seen during training.
But it also allows for relative position information to be captured.
³Analogy from Benjamin van Niekerk.
Multi-head attention
Hypothetical example:
[Figure: two attention heads over "i went to school to learn": head [1] attends to semantically related words, head [2] to syntactically related words.]
Analogy: Each head is like a different kernel in a CNN
From Lena Voita’s blog:
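A sketch of multi-head self-attention: each head has its own projection matrices (like separate kernels in a CNN), and the head outputs are concatenated and mixed with an output projection, as in Vaswani et al. (2017). The function and variable names here are just for illustration.

```python
import numpy as np

def softmax(A):
    expA = np.exp(A - A.max(axis=-1, keepdims=True))
    return expA / expA.sum(axis=-1, keepdims=True)

def multi_head_self_attention(X, heads, Wo):
    """heads is a list of (Wq, Wk, Wv) triples, one per attention head."""
    outputs = []
    for Wq, Wk, Wv in heads:                            # each head has its own projections
        Q, K, V = X @ Wq, X @ Wk, X @ Wv
        outputs.append(softmax(Q @ K.T / np.sqrt(K.shape[-1])) @ V)
    return np.concatenate(outputs, axis=-1) @ Wo        # concatenate heads and project
```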
Masking the future in self-attention
If we have a network or decoder that needs to be causal, then we
should ensure that it can only attend to the past when making the
current prediction.
E.g. if we are doing language modelling:
[Figure: masked self-attention for position 4 of "i went to school to learn": the query q4 only attends to keys k1, . . . , k4 and values v1, . . . , v4 when producing y4.]
Mathematically:

$$a_{i,t} = \begin{cases} \dfrac{\mathbf{q}_i^\top \mathbf{k}_t}{\sqrt{D_k}} & \text{if } t \leq i \\[6pt] -\infty & \text{if } t > i \end{cases}$$
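A sketch of how the mask is applied in practice: the scores for t > i are set to negative infinity before the softmax, so those positions get zero weight.

```python
import numpy as np

def masked_self_attention(X, Wq, Wk, Wv):
    """Causal self-attention: position i can only attend to positions t <= i."""
    T = X.shape[0]
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    A = Q @ K.T / np.sqrt(K.shape[-1])                # scores a_{i,t}
    A[np.triu_indices(T, k=1)] = -np.inf              # a_{i,t} = -inf for t > i
    expA = np.exp(A - A.max(axis=-1, keepdims=True))
    alpha = expA / expA.sum(axis=-1, keepdims=True)   # future positions get weight 0
    return alpha @ V
```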
Have a careful look at what happens in the Google transformer diagram
for machine translation:
Cross-attention
[Figure: cross-attention in the encoder-decoder model translating "hy het my gegooi" to "he threw me".]

• Keys and values: from the encoder
• Queries: from the decoder
Have a look at the Google transformer diagram again.
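A sketch of this layer, assuming the decoder states X_dec provide the queries and the encoder outputs H_enc provide the keys and values (the names are illustrative):

```python
import numpy as np

def softmax(A):
    expA = np.exp(A - A.max(axis=-1, keepdims=True))
    return expA / expA.sum(axis=-1, keepdims=True)

def cross_attention(X_dec, H_enc, Wq, Wk, Wv):
    """Queries from the decoder; keys and values from the encoder output."""
    Q = X_dec @ Wq                                        # (T_dec, Dk)
    K = H_enc @ Wk                                        # (T_enc, Dk)
    V = H_enc @ Wv                                        # (T_enc, Dv)
    return softmax(Q @ K.T / np.sqrt(K.shape[-1])) @ V    # (T_dec, Dv)
```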
Transformer
Figure from (Vaswani et al., 2017).
We haven’t spoken about the add & norm block:
• Residual connections
• Layer normalisation
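A rough sketch of what these two parts compute together (post-norm ordering, as in the original transformer); gamma and beta are the learnable layer-norm gain and bias:

```python
import numpy as np

def layer_norm(X, gamma, beta, eps=1e-5):
    """Normalise each position's vector to zero mean and unit variance, then scale and shift."""
    mean = X.mean(axis=-1, keepdims=True)
    std = X.std(axis=-1, keepdims=True)
    return gamma * (X - mean) / (std + eps) + beta

def add_and_norm(X, sublayer_out, gamma, beta):
    """Residual connection around a sublayer (e.g. self-attention), followed by layer norm."""
    return layer_norm(X + sublayer_out, gamma, beta)
```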
Videos covered in this note
• Intuition behind self-attention (12 min)
• Attention recap (6 min)
• Self-attention details (13 min)
• Self-attention in matrix form (5 min)
• Positional encodings in transformers (19 min)
• The clock analogy for positional encodings (5 min)
• Multi-head attention (5 min)
• Masking the future in self-attention (5 min)
• Cross-attention (7 min)
• Transformer (4 min)
Acknowledgments
Christiaan Jacobs and Benjamin van Niekerk were instrumental in
helping me to start to understand self-attention and transformers.
This note relied heavily on content from:
• Chris Manning's CS224N course at Stanford University, particularly the transformer lecture by John Hewitt
• Lena Voita’s NLP course for you
Further reading
A. Goldie, “CS224N: Pretraining,” Stanford University, 2022.
A. Huang, S. Subramanian, J. Sum, K. Almubarak, and S. Biderman,
“The annotated transformer,” Harvard University, 2022.
References
T. Denk, “Linear relationships in the transformer’s positional encoding,”
2019.
J. Hewitt, “CS224N: Self-attention and transformers,” Stanford University, 2023.
A. Vaswani, N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, A. N.
Gomez, L. Kaiser, and I. Polosukhin, “Attention is all you need,” in
NeurIPS, 2017.
L. Voita, “Sequence to sequence (seq2seq) and attention,” 2023.
A. Zhang, Z. C. Lipton, M. Li, and A. J. Smola, Dive into Deep
Learning, 2021.