The only time the query and key matrices are used is to compute the attention scores, i.e. $v_i^T W_q^T W_k v_j$. But all that ever enters the computation is the product $W_q^T W_k$. Why not just replace $W_q^T W_k$ with a single matrix $W_{qk}$ and learn that product directly instead of the two factors themselves? How does it help to have two matrices instead of one? And if it does help, why isn't the same thing done for the weight matrices between neuron layers?
ChatGPT tells me the reason is that it allows the model to learn different representations for the query and the key. But since they are only ever dotted together, it seems to me you could just use the original embedding as the query with no loss of generality.
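To make the claim concrete, here is a minimal NumPy sketch (toy sizes; the names `W_q`, `W_k`, `M` and the dimensions are my own choices): the pre-softmax scores come out identical whether you keep the two projections or fuse them into one matrix.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_head, n = 8, 4, 5               # toy sizes, chosen for illustration

X = rng.normal(size=(n, d_model))          # one embedding v_i per row
W_q = rng.normal(size=(d_head, d_model))   # q_i = W_q v_i
W_k = rng.normal(size=(d_head, d_model))   # k_j = W_k v_j

# Separate projections, as in the standard formulation.
Q = X @ W_q.T                              # (n, d_head)
K = X @ W_k.T                              # (n, d_head)
scores_two = Q @ K.T                       # scores[i, j] = v_i^T W_q^T W_k v_j

# Fused matrix, as the question proposes.
M = W_q.T @ W_k                            # (d_model, d_model), rank <= d_head
scores_one = X @ M @ X.T

assert np.allclose(scores_two, scores_one)  # identical pre-softmax scores
```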
I would suggest looking into the math a little more. The query, key, and value vectors in the attention layer are all (linear) functions of the input sequence, so the attention weights are the softmax of a quadratic form in the input, iirc.
Can you show that ordering equivariance holds for the single-matrix version just as it does for the two-matrix version?
This form of attention must be equivariant with respect to token order, e.g.
attn(ABCD) == rot2(attn(rot2(ABCD))) == rot2(attn(CDAB))
I am using rot2 here for rotating the token sequence by two positions.
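A quick numerical check of that, in a minimal NumPy sketch (plain unmasked single-head attention, no positional encodings; all names and sizes are my own): permuting the tokens just permutes the attention output the same way. Since the fused matrix produces exactly the same scores, it inherits the property unchanged.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def attn(X, W_q, W_k, W_v):
    # Plain (unmasked, no positional encoding) single-head attention.
    Q, K, V = X @ W_q.T, X @ W_k.T, X @ W_v.T
    A = softmax(Q @ K.T / np.sqrt(K.shape[-1]))
    return A @ V

rng = np.random.default_rng(0)
d_model, d_head, n = 8, 4, 4
X = rng.normal(size=(n, d_model))
W_q, W_k, W_v = (rng.normal(size=(d_head, d_model)) for _ in range(3))

rot2 = np.roll(np.eye(n), 2, axis=0)   # permutation matrix: ABCD -> CDAB

# Equivariance: rotating the input just rotates the output the same way.
assert np.allclose(attn(rot2 @ X, W_q, W_k, W_v),
                   rot2 @ attn(X, W_q, W_k, W_v))
```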
If I’m understanding your question correctly, it probably doesn’t make any difference computation-wise. But if we took "query dot key" as a single input, the attention layer would just have two inputs: 1. a query-dot-key matrix; 2. a value matrix. I think that would be a worse formulation than the one in the original paper, even though they are the same computation-wise. By keeping separate query and key matrices, the data flow is clearer. For example, the encoder-decoder attention layer takes the encoder block's output as key and value, but the processed target sequence as query. That idea comes across very clearly with the original attention formulation (see the sketch below).
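A minimal sketch of that encoder-decoder data flow (the function and variable names are mine, just for illustration); expressing it is natural when the query, key, and value projections are kept separate:

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def cross_attention(dec_state, enc_out, W_q, W_k, W_v):
    # Encoder-decoder attention: queries come from the (processed) target
    # sequence, keys and values come from the encoder output.
    Q = dec_state @ W_q.T
    K = enc_out @ W_k.T
    V = enc_out @ W_v.T
    return softmax(Q @ K.T / np.sqrt(K.shape[-1])) @ V

rng = np.random.default_rng(0)
d_model, d_head, n_src, n_tgt = 8, 4, 6, 3
enc_out   = rng.normal(size=(n_src, d_model))   # encoder block output
dec_state = rng.normal(size=(n_tgt, d_model))   # processed target sequence
W_q, W_k, W_v = (rng.normal(size=(d_head, d_model)) for _ in range(3))
out = cross_attention(dec_state, enc_out, W_q, W_k, W_v)   # (n_tgt, d_head)
```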
It’s the same mathematically but not computation-wise: the tokens are projected down to a smaller head dimension d. The two projection matrices hold 2·d_model·d parameters and the pairwise dot products run in dimension d, whereas a fused weight matrix would hold d_model² parameters and the pairwise products would run in the full model dimension.
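Rough numbers, per attention head, using GPT-2-small-ish sizes purely for illustration (my own choice, ignoring constant factors):

```python
d_model, d_head, n_ctx = 768, 64, 1024   # illustrative sizes, not from the thread

# Parameters per head:
params_separate = 2 * d_model * d_head   # W_q and W_k -> 98_304
params_fused    = d_model * d_model      # one fused (d_model x d_model) matrix -> 589_824

# Multiply-adds for the full score matrix:
flops_separate = 2 * n_ctx * d_model * d_head + n_ctx**2 * d_head  # project once, then dot in d_head
flops_fused    = n_ctx * d_model**2 + n_ctx**2 * d_model           # x @ M once per token, then dot in d_model
```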
This.
Something to add to the other great answers here - you can say something similar about head-specific matrices W_V and W_O - they always act together as well. In fact, Anthropic recommends thinking of W_OW_V and W_Q^TW_K as basic primitives in their transformer interpretability framework: https://transformer-circuits.pub/2021/framework/index.html
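For illustration, the two combined matrices per head look like this (shapes and names are my own, loosely following the framework's W_QK / W_OV naming):

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_head = 8, 2

W_Q, W_K, W_V = (rng.normal(size=(d_head, d_model)) for _ in range(3))
W_O = rng.normal(size=(d_model, d_head))   # per-head output projection

# The two low-rank matrices a head effectively applies to the residual stream:
W_QK = W_Q.T @ W_K      # (d_model, d_model): governs where the head attends
W_OV = W_O @ W_V        # (d_model, d_model): governs what it moves once it attends

print(np.linalg.matrix_rank(W_QK), np.linalg.matrix_rank(W_OV))  # both <= d_head
```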
I don’t remember the ref but I browsed a theory paper at some point that did consider that representation (the product explicitly), possibly with something like nuclear norm regularization to keep the rank low.
On the EleutherAI Discord, someone once asked that question. And someone else replied that yeah, obviously having one matrix instead of two should be better in theory, but in practice, empirically, it makes things worse. Why? No one knows.
These answers seem weird to me. Am I misunderstanding? Here’s the obvious-seeming answer to me:
You need two different matrices because you need an attention coefficient for every single pair of vectors.
If there are n tokens, then with a causal mask the n-th token needs n attention coefficients (one for each token it can attend to, itself included), the (n-1)-th token needs n-1, and so on down to the first token, which only attends to itself.
That’s on the order of n² coefficients in total. If you instead compute key and query vectors, you only need 2n of them (one key and one query per token). If the key/query vectors are d-dimensional, that's 2dn numbers, which is smaller than n² whenever the context length exceeds twice the key/query dimension.
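To put illustrative numbers on that (my own choice of sizes):

```python
n, d = 4096, 64                 # context length and key/query dimension (assumed)

pairwise_coefficients = n * (n + 1) // 2   # ~n^2/2 causal attention scores
key_query_numbers     = 2 * n * d          # one d-dim key and query per token

print(pairwise_coefficients)    # 8_390_656
print(key_query_numbers)        # 524_288 -- far fewer numbers to produce and store
```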
So using separate vectors is more efficient and more scalable.
The other answers in this thread seem to be saying something different, which surprises me, since this answer feels very straightforward. If I’m missing something, I’d love an explanation.
Dammit, all of the answers are fkin terrible. Looks like the AI bots took over, or everyone in this subreddit has become braindead since the blackout.
You obviously don’t do W_q @ W_k. That’s totally stupid.
What transformers do is (x_i @ W_q) @ (x_j @ W_k), where x_i and x_j are two tokens in the sequence. This is an interaction between pairs of tokens, so it can’t be precomputed. What you see written in the papers is Q = x_i @ W_q and K = x_j @ W_k.
(Transposes omitted for notational clarity, work that out yourself)
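Spelled out in a short NumPy sketch (my own toy sizes): the pairwise scores indeed can't be precomputed before seeing the input, but the point of the original question is that the product of the weight matrices themselves *can* be fused, and it gives the same numbers.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_head, n = 8, 4, 5
X = rng.normal(size=(n, d_model))          # token embeddings, one per row
W_q = rng.normal(size=(d_model, d_head))
W_k = rng.normal(size=(d_model, d_head))

Q = X @ W_q                                # x_i @ W_q for every i at once
K = X @ W_k
scores = Q @ K.T                           # the (x_i @ W_q) . (x_j @ W_k) interaction

# What *can* be precomputed is the product of the weights themselves:
scores_fused = X @ (W_q @ W_k.T) @ X.T     # same numbers, one (d_model x d_model) matrix
assert np.allclose(scores, scores_fused)
```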
Your answer is also terrible. It does not answer his question.
Look at the top 2 replies to see correct interpretations of the question.
If you keep the matrices separate, you can control the rank of the learned weights: the product W_q^T W_k can have rank at most the head dimension.
Otherwise, the single unconstrained matrix will generically be full rank.
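A quick illustration of that rank cap (sizes are my own):

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_head = 512, 64

W_q = rng.normal(size=(d_head, d_model))
W_k = rng.normal(size=(d_head, d_model))

M_factored = W_q.T @ W_k                          # rank is capped at d_head
M_free     = rng.normal(size=(d_model, d_model))  # unconstrained, generically full rank

print(np.linalg.matrix_rank(M_factored))  # 64
print(np.linalg.matrix_rank(M_free))      # 512
```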