We propose a technical exploration of In-Context Learning (ICL) for linear regression tasks in transformer architectures. Focusing on the article Transformers Learn In-Context by Gradient Descent by J. von Oswald et al., published at ICML 2023, we provide detailed explanations and illustrations of the mechanisms involved. We also contribute novel analyses of ICL, discuss recent developments, and point to open questions in this area of research.
In-Context Learning (ICL) is the behavior, first observed in Large Language Models (LLMs), whereby learning occurs from prompted data without any modification of the model weights.
Interestingly, it was around the release of GPT-2 and GPT-3 that researchers observed that an auto-regressive language model pre-trained on enough data with enough parameters could perform arbitrary tasks without fine-tuning, simply by prompting the model with a few examples of the task and letting it generate the output. In recent months, the research community has started to investigate the phenomenon of ICL in more detail, and several papers have been published on the topic.
Specifically, since learning processes in biology and in machines are often, if not always, understood in terms of iterative optimization, it is natural to ask what kind of iterative optimization is being realized during ICL, and how.
Though ICL is generally regarded as a phenomenon exhibited by LLMs, we study it here in a non-language, small-scale model that offers more control and in which ICL can still be shown to emerge. This simpler setting is that of a transformer trained to regress a set of numerical data points presented in the prompt, where the data points of each prompt are generated from a distinct function, but all prompts sample their function from the same general class (i.e. linear) at train and at test time. We will see that, to some extent, this simplification allows for a mathematical treatment of ICL.
The following figure gives a visual representation of the ICL setup we will consider in this blog post. The model is a generic transformer pre-trained to solve generic linear regression tasks. At inference time, we can give the model a prompt with a new linear regression task, and it is able to solve it with surprisingly good performance.
The objective of this blog post is to understand how ICL is possible, and to present in an interactive way what is known of its underlying mechanism. Specifically, we will analyze the results reported in the paper Transformers Learn In-Context by Gradient Descent by J. von Oswald et al., published at ICML 2023.
First of all we need to agree on a mathematical formalization of in-context learning.
Before we start, let’s introduce some notation and color convention that will be used throughout the rest of the blog post. We will use the following colors to denote different quantities:
Vectors will be denoted with bold letters, e.g. \(\mba\), and matrices with bold capital letters, e.g. \(\mbA\). Additional notation will be introduced in-line when needed.
Formally, let’s define \(p(\mbx)\) as a probability distribution over inputs \(\mbx\in\cX\) and \(\cH\) a class of functions \(h: \cX \rightarrow \cY\). You can think of \(\cH\) as a set of functions that share some common properties, for example, the set of all linear functions, or the set of all functions that can be represented by a neural network with a given architecture. Also, let’s define \(p(h)\) as a probability measure over \(\cH\).
Following the terminology of the LLM community, let’s define a prompt \(P\) of length \(C\) as a sequence of \(2C+1\) points \((\mbx_0, h(\mbx_0), \ldots, \mbx_{C-1}, h(\mbx_{C-1}), \mbx_{\text{query}})\) where inputs (\(\mbx_i\) and \(\mbx_{\text{query}}\)) are independently and identically drawn from \(p(\mbx)\), and \(h\) is drawn from \(\cH\). In short we will also write \(P_C = \left[\{\mbx_i, h(\mbx_i)\}_{i=0}^{C-1}, \mbx_\text{query}\right]\).
For our setup, we will consider a linear regression problem, where the goal is to learn a linear function \(h_{\mbw}(\mbx) = \mbw^\top\mbx\), with \(\mbw\in\bbR^D\), from a set of in-context examples \(\{\mbx_i, \y_i\}_{i=0}^{C-1}\), where \(\mbx_i\in\bbR^D\) and \(\y_i\in\bbR\). So \(h_{\mbw} \in \cH\).
In order to better understand how the prompt is constructed starting from a regression task, let’s consider the following visual example:
The figure shows a visual representation of the construction of a single input prompt. In particular, we first sample a weight \(\mbw\) from the distribution \(p(\mbw)\), and then we sample \(C\) inputs \(\mbx_i\) from \(p(\mbx)\), where \(C\) is the fixed context size. Finally, we compute the corresponding outputs \(\y_i = \mbw^\top\mbx_i\). We consider \(p(\mbx) = \cU(-1, 1)\), where \(\cU\) is the uniform distribution, and \(p(\mbw) = \cN(\mbzero, \alpha^2\mbI)\), where \(\cN\) is a \(D\)-dimensional Gaussian distribution with zero mean and standard deviation \(\alpha\).
Defining \(c=C+1\) and \(d=D+1\), where \(C\) is the context size and \(D\) is the input dimension, we can represent the input as a matrix \(\mbE\in\bbR^{d\times c}\) (also referred to as token embeddings or, simply, embeddings), where each of the first \(C\) columns stacks a context input \(\mbx_i\) with its output \(\y_i\), and the last column contains the query input \(\mbx_{\text{query}}\), zero-padded in place of the unknown output.
To construct a batch of regression problems, we just repeat the above procedure \(N\) times with the fixed context size \(C\), where \(N\) is the size of the batch.
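To make this construction concrete, here is a minimal JAX sketch of how a batch of prompts could be generated. The function and variable names are our own (not from the paper's codebase); each task samples its own \(\mbw\), and the query column is zero-padded in the output channel.

```python
import jax
import jax.numpy as jnp

def sample_prompt(key, C=10, D=2, alpha=1.0):
    """Build one prompt as an embedding matrix E of shape (D+1, C+1)."""
    kw, kx, kq = jax.random.split(key, 3)
    w = alpha * jax.random.normal(kw, (D,))                   # w ~ N(0, alpha^2 I)
    x = jax.random.uniform(kx, (C, D), minval=-1, maxval=1)   # context inputs
    y = x @ w                                                 # y_i = w^T x_i
    x_query = jax.random.uniform(kq, (D,), minval=-1, maxval=1)
    y_query = x_query @ w
    context = jnp.concatenate([x, y[:, None]], axis=1).T        # (D+1, C)
    query = jnp.concatenate([x_query, jnp.zeros(1)])[:, None]   # zero-padded output
    E = jnp.concatenate([context, query], axis=1)               # (D+1, C+1)
    return E, y_query

# A batch of N tasks, each with its own w (the model never sees a task twice).
keys = jax.random.split(jax.random.PRNGKey(0), 64)
E_batch, y_query_batch = jax.vmap(sample_prompt)(keys)
```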
In this section we will briefly review the self-attention mechanism, which is the core component of the transformer architecture
Let \(\mbW^K, \mbW^Q \in \bbR^{d_k\times d}\), \(\mbW^V \in \bbR^{d_v\times d}\), and \(\mbW^P \in \bbR^{d \times d_v}\) be the key, query, value, and projection weight matrices, respectively. Given an embedding \(\mbE\in\bbR^{d\times c}\), the softmax self-attention layer implements the following operation,
\[\begin{equation} \label{eq:softmax-self-attention} f_\text{attn} (\mbtheta_\text{attn}, \mbE) = \mbE + \mbW^P \mbW^V \mbE \sigma\left(\frac{(\mbW^K \mbE)^\top \mbW^Q \mbE}{\sqrt{d}}\right), \end{equation}\]with \(\mbtheta_\text{attn}=\{\mbW^K, \mbW^Q, \mbW^V, \mbW^P\}\), where for simplicity we will consider \(d_k=d_v=d\), and \(\sigma(\cdot)\) is the softmax function applied column-wise. It’s simple to verify that the output dimension of \(f_\text{attn}\) is the same as the input dimension. To simplify further, we can also define the value, key and query matrices as \(\mbV = \mbW^V\mbE\), \(\mbK = \mbW^K\mbE\), \(\mbQ = \mbW^Q\mbE\), respectively.
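For reference, a direct JAX transcription of Equation \eqref{eq:softmax-self-attention} could look as follows (a minimal sketch with \(d_k=d_v=d\), as assumed above; the parameter names are ours).

```python
import jax
import jax.numpy as jnp

def softmax_self_attention(params, E):
    """params: dict with 'WK', 'WQ', 'WV', 'WP', each of shape (d, d); E: (d, c)."""
    K = params['WK'] @ E                      # keys,    (d, c)
    Q = params['WQ'] @ E                      # queries, (d, c)
    V = params['WV'] @ E                      # values,  (d, c)
    d = E.shape[0]
    A = jax.nn.softmax((K.T @ Q) / jnp.sqrt(d), axis=0)   # column-wise softmax, (c, c)
    return E + params['WP'] @ V @ A           # residual update, (d, c)
```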
Once the dataset is created, we can train the model using the following objective:
\[\begin{equation} \label{eq:pre-train-loss-expectation} \cL(\mbtheta) = \mathbb{E}\left\|f\left(\mbtheta, \left[\{\mbx_i, \y_i\}_{i=0}^{C-1}, \mbx_\text{query}\right]\right) - \y_{\text{query}}\right\|^2, \end{equation}\]where the expectation is taken over \(p(\mbx)\) and \(p(\mbw)\), with \(h_{\mbw}(\mbx) = \mbw^\top\mbx\). Note that the output of the model is a sequence of \(C+1\) values, i.e. the same length as the input prompt, and the loss is computed only on the last value of the sequence, which corresponds to the predicted query output \(\widehat\y_{\text{query}}\). Specifically, to read out the prediction for \(\mbx_{\text{query}}\), we multiply this last value by \(-1\). Note that this choice is completely transparent during model training, as it is equivalent to simply changing the sign of a few elements in the projection weight matrix \(\mbW^P\); the reason for it will become clear in the following sections. At each training iteration, we replace the expectation with an empirical average over a batch of \(N\) regression tasks, each made of a different set of context points \(\{\mbx_i^{(n)}, \y_i^{(n)}\}_{i=0}^{C-1}\) and a query input/target pair, \(\mbx^{(n)}_\text{query}\) and \(\y^{(n)}_{\text{query}}\), respectively. Note that because the dataset is created online, during training the model never sees the same regression task twice.
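As a sketch of how this objective could be implemented in JAX: we assume a `model(params, E)` function (for instance, a stack of the attention layers above) that maps an embedding matrix to an embedding matrix of the same shape. The names and the SGD step below are illustrative, not taken from the paper's code.

```python
import jax
import jax.numpy as jnp

def predict_query(params, E, model):
    """Read out the prediction for x_query: last row, last column, times -1."""
    out = model(params, E)          # (d, c), same shape as the input prompt
    return -out[-1, -1]

def batch_loss(params, E_batch, y_query_batch, model):
    """Empirical version of the pre-training loss: mean squared error on the query."""
    preds = jax.vmap(lambda E: predict_query(params, E, model))(E_batch)
    return jnp.mean((preds - y_query_batch) ** 2)

# One SGD step (lr is a placeholder value):
# grads = jax.grad(batch_loss)(params, E_batch, y_query_batch, model)
# params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
```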
With all the preliminaries and notations in place, we can now start to analyze some results regarding the ability of transformers to learn linear functions in-context. One of the first papers that studied the ability of transformers to learn linear functions in-context is What Can Transformers Learn In-Context? A Case Study of Simple Function Classes by S. Garg et al
In the figure below, we report the in-context test loss (as defined in Equation \eqref{eq:in-context-test-loss}) for each model configuration, for various context sizes \(C\), from 2 to 100.
The experiment above shows that the test loss diminishes for larger context sizes, and also as the number of layers increases. These two main effects are clearly expected, as consequences of more data points and more compute, respectively, and they replicate the findings of Garg et al
From this point, we will depart from the classic softmax self-attention layer, and restrict our study to a linear self-attention layer, which is the setting considered in the paper of J. von Oswald et al
A linear self-attention layer updates the embeddings \(\mbE\) as follows:
\[\begin{equation} f_\text{linattn} (\mbtheta_\text{linattn}, \mbE) = \mbE + \frac{\mbW^P \mbV\left(\mbK^\top \mbQ \right)}{\sqrt{d}}, \end{equation}\]with \(\mbV, \mbK, \mbQ\) being the value, key and query defined right after Equation \eqref{eq:softmax-self-attention}.
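In code, the only change with respect to the softmax version sketched earlier is dropping the softmax (keeping the \(1/\sqrt{d}\) scaling of the equation).

```python
import jax.numpy as jnp

def linear_self_attention(params, E):
    """Linear self-attention: same residual update as before, without the softmax."""
    K, Q, V = params['WK'] @ E, params['WQ'] @ E, params['WV'] @ E
    d = E.shape[0]
    return E + params['WP'] @ V @ (K.T @ Q) / jnp.sqrt(d)
```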
Now, to analyze if a linear self-attention layer is sufficient to learn linear functions in-context, we can use the same experimental setup as before, but replacing the softmax self-attention layer with a linear self-attention layer.
Additionally, we also strip down the transformer to its bare minimum, i.e. we remove the normalization, the embedding layer, the feed-forward layer, and only use a single head. The only remaining component is the linear self-attention layer. Therefore, in the following we use the term “linear transformer” to refer to this simplified model.
We test the linear transformer on the same dataset setup as before, and we will use the same number of layers as before, i.e. 1, 2, 3, 4, 5.
From the previous section we have seen that a linear self-attention layer is sufficient to learn linear functions in-context. In this section we will try to understand why this is the case, starting from a review of least-squares regression and gradient descent.
In this section, we establish an important connection that will be fundamental to understanding the mechanism behind ICL with linear self-attention. Starting from a simple linear regression problem, we will show that the loss obtained after one gradient step can also be obtained by keeping the weights fixed and suitably transforming the inputs and targets.
The loss for a linear regression problem is defined as: \(\begin{equation} \label{eq:linear-regression-loss} \cL_{\text{lin}}\left(\mbw, \{\mbx_i, {\y}_i\}_{i=0}^{C-1}\right) = \frac 1 {2C} \sum_{i=0}^{C-1} (\mbw^\top\mbx_i - \y_i)^2 \end{equation}\)
where \(\mbw\in\bbR^D\), \(\mbx_i\in\bbR^D\) and \(\y_i\in\bbR\). With a given learning rate \(\eta\), the gradient descent update is \(\mbw \leftarrow \mbw - \Delta \mbw\), where \(\begin{equation} \label{eq:linear-regression-gd-gradient} \Delta \mbw = \eta \nabla_{\mbw} \cL_{\text{lin}}\left(\mbw, \{\mbx_i, {\y}_i\}_{i=0}^{C-1}\right) = \frac{\eta}{C} \sum_{i=0}^{C-1} \left(\mbw^\top\mbx_i - \y_i\right)\mbx_i \end{equation}\) The corresponding loss (after the update) is: \(\begin{equation} \label{eq:linear-regression-loss-after-gd} \cL_{\text{lin}}\left(\mbw - \Delta \mbw, \{\mbx_i, {\y}_i\}_{i=0}^{C-1}\right) = \frac 1 {2C} \sum_{i=0}^{C-1} \left(\mbw^\top\mbx_i - \y_i - \Delta \mbw^\top\mbx_i\right)^2 \end{equation}\)
It is trivial to see that if we now define \(\widehat{\mbx}_i = \mbx_i\) and \(\widehat{\y}_i = \y_i + \Delta \mbw^\top\mbx_i\), we can compute Equation \eqref{eq:linear-regression-loss} with the new inputs and targets, i.e. \(\cL_{\text{lin}}(\mbw, \{\widehat{\mbx}_i, \widehat{\y}_i\}_{i=0}^{C-1})\), which is the same as the loss after the gradient descent update (Equation \eqref{eq:linear-regression-loss-after-gd}).
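This claim is easy to check numerically; here is a small sketch with arbitrary data, using the definitions above (the learning rate and dimensions are placeholder values).

```python
import jax
import jax.numpy as jnp

def lin_loss(w, x, y):
    return 0.5 * jnp.mean((x @ w - y) ** 2)      # least-squares regression loss

key = jax.random.PRNGKey(0)
kw, kx, kt = jax.random.split(key, 3)
C, D, eta = 10, 2, 0.1
w = jax.random.normal(kw, (D,))
x = jax.random.uniform(kx, (C, D), minval=-1, maxval=1)
y = x @ jax.random.normal(kt, (D,))              # targets from some ground-truth w*

delta_w = eta * jax.grad(lin_loss)(w, x, y)      # one GD step direction
y_hat = y + x @ delta_w                          # transformed targets, same inputs

# Loss after the GD step == loss of the *original* w on the transformed data.
print(jnp.allclose(lin_loss(w - delta_w, x, y), lin_loss(w, x, y_hat)))  # True
```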
As we just saw, the starting intuition is that we can emulate a gradient step on the linear regression loss by manipulating the inputs and the targets. This is the key insight of von Oswald et al.
Before stating the main result, recall the definitions of value, key and query as \(\mbV = \mbW^V\mbE\), \(\mbK = \mbW^K\mbE\), and \(\mbq_j = \mbW^Q\mbe_j\).
Main result: Given a 1-head linear attention layer and the tokens \(\mbe_j = (\mbx_j, \y_j)\), for \(j=0,\ldots,C-1\), we can construct key, query and value matrices \(\mbW^K, \mbW^Q, \mbW^V\) as well as the projection matrix \(\mbW^P\) such that a transformer step on every token, \(\mbe_j \leftarrow (\mbx_j, \y_{j}) + \mbW^{P} \mbV \mbK^{\top}\mbq_{j}\), is identical to the gradient-induced dynamics \(\mbe_j \leftarrow (\mbx_j, \y_j) + (0, \Delta \mbw^\top \mbx_j)\), i.e. each context target is transformed into \(\widehat\y_j = \y_j + \Delta\mbw^\top\mbx_j\), exactly as in the previous section. For the query token \((\mbx_{\text{query}}, 0)\), the dynamics are identical.
For notation, we will identify with \(\mbtheta_\text{GD}\) the set of parameters of the linear transformer that implements a gradient descent step.
We can indeed construct a linear self-attention layer that implements a gradient descent step. The construction is not unique; one possibility, written in block form, is the following.
\[\begin{align} \mbW^K = \mbW^Q = \left(\begin{array}{@{}c c@{}} \mbI_D & 0 \\ 0 & 0 \end{array}\right) \end{align}\]with \(\mbI_D\) the identity matrix of size \(D\), and
\[\begin{align} \mbW^V = \left(\begin{array}{@{}c c@{}} 0 & 0 \\ \mbw_0^\top & -1 \end{array} \right) \end{align}\]with \(\mbw_0 \in \bbR^{D}\) the weight vector of the linear model and \(\mbW^P = \frac{\eta}{C}\mbI_{d}\) with identity matrix of size \(d\).
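The construction is easy to verify numerically. Below is a sketch (with \(\mbw_0 = \mbzero\) and illustrative values for \(D\), \(C\), and \(\eta\)) that compares one linear self-attention step with one explicit GD step on the least-squares loss; note that here the \(1/\sqrt{d}\) factor of the attention layer is folded into the projection so that the overall scaling is exactly \(\eta/C\).

```python
import jax
import jax.numpy as jnp

D, C, eta = 2, 10, 0.1
key = jax.random.PRNGKey(1)
kw, kx, kq = jax.random.split(key, 3)
w_star = jax.random.normal(kw, (D,))
x = jax.random.uniform(kx, (C, D), minval=-1, maxval=1)
y = x @ w_star
x_q = jax.random.uniform(kq, (D,), minval=-1, maxval=1)

# Tokens: context columns (x_i, y_i), query column (x_q, 0).
E = jnp.concatenate([jnp.concatenate([x, y[:, None]], 1).T,
                     jnp.concatenate([x_q, jnp.zeros(1)])[:, None]], axis=1)

# GD-equivalent weights (with w_0 = 0); the attention scaling is exactly eta / C.
d = D + 1
WK = WQ = jnp.zeros((d, d)).at[:D, :D].set(jnp.eye(D))
WV = jnp.zeros((d, d)).at[D, D].set(-1.0)           # bottom row: (w_0^T, -1) = (0, -1)
WP = (eta / C) * jnp.eye(d)

E_new = E + WP @ (WV @ E) @ ((WK @ E).T @ (WQ @ E))  # one linear-attention step
tf_pred = -E_new[-1, -1]                             # read-out: -1 * last entry

# One explicit GD step on the least-squares loss, starting from w_0 = 0.
w0 = jnp.zeros(D)
delta_w = (eta / C) * ((x @ w0 - y) @ x)             # eta * gradient
gd_pred = (w0 - delta_w) @ x_q

print(jnp.allclose(tf_pred, gd_pred))                # True
```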
If you are interested in the proof of construction for the GD-equivalent transformer, you can find it in the following collapsible section.
To verify this, first remember that if \(\mbA\) is a matrix of size \(N\times M\) and \(\mbB\) is a matrix of size \(M\times P\),
\[\begin{align} \mbA\mbB = \sum_{i=1}^M \mba_i\otimes\mbb_{i,:} \end{align}\]where \(\mba_i \in \bbR^{N}\) is the \(i\)-th column of \(\mbA\), \(\mbb_{i,:} \in \bbR^{P}\) is the \(i\)-th row of \(\mbB\), and \(\otimes\) denotes the outer product of two vectors.
It is easy to verify that with this construction we obtain the following dynamics
\[\begin{align} \left(\begin{array}{@{}c@{}} \mbx_j\\ \y_j \end{array}\right) \leftarrow & \left(\begin{array}{@{}c@{}} \mbx_j\\ \y_j \end{array}\right) + \mbW^{P} \mbV \mbK^{\top}\mbq_{j} = \mbe_j + \frac{\eta}{C} \sum_{i={0}}^{C-1} \left(\begin{array}{@{}c c@{}} 0 & 0 \\ \mbw_0^\top & -1 \end{array} \right) \left(\begin{array}{@{}c@{}} \mbx_i\\ \y_i \end{array}\right) \otimes \left( \left(\begin{array}{@{}c c@{}} \mbI_D & 0 \\ 0 & 0 \end{array}\right) \left(\begin{array}{@{}c@{}} \mbx_i\\ \y_i \end{array}\right) \right) \left(\begin{array}{@{}c c@{}} \mbI_D & 0 \\ 0 & 0 \end{array}\right) \left(\begin{array}{@{}c@{}} \mbx_j\\ \y_j \end{array}\right)\\ &= \left(\begin{array}{@{}c@{}} \mbx_j\\ \y_j \end{array}\right) + \frac{\eta}{C} \sum_{i={0}}^{C-1} \left(\begin{array}{@{}c@{}} 0\\ \mbw_0^\top \mbx_i - \y_i \end{array}\right) \otimes \left(\begin{array}{@{}c@{}} \mbx_i\\ 0 \end{array}\right) \left(\begin{array}{@{}c@{}} \mbx_j\\ 0 \end{array}\right) = \left(\begin{array}{@{}c@{}} \mbx_j\\ \y_j \end{array}\right) + \left(\begin{array}{@{}c@{}} 0\\ \frac{\eta}{C}\sum_{i=0}^{C-1} \left( \left(\mbw_0^\top\mbx_i - \y_i\right)\mbx_i\right)^\top \mbx_j \end{array}\right), \end{align}\]which is exactly the update \((\mbx_j, \y_j) + (0, \Delta\mbw^\top\mbx_j)\), with the gradient evaluated at \(\mbw_0\). Note that the update for the query token \((\mbx_{\text{query}}, \textcolor{output}{0})\) is identical to the update for the context tokens \((\mbx_j, \y_j)\) for \(j=0,\ldots,C-1\).
Now let’s do some experiments to verify the theoretical results. We will work within the same experimental setup as before with the same dataset construction, training procedure and testing procedure. In this first section, we consider a linear transformer with a single layer, and the transformer built as described in the previous section (the GD-equivalent transformer), i.e. with a linear self-attention layer that implements a gradient descent step.
We now study the evolution of the test loss of a linear transformer during training \(\cL(\mbtheta)\), and compare it to the loss of a transformer implementing a gradient descent step \(\cL(\mbtheta_\text{GD})\).
Although an empirical proof of such a functional equivalence would require checking the outputs for all possible test samples, we can gather more evidence by looking more closely at the computations that unfold in the linear transformer during one forward pass.
To better understand the dynamics of the linear transformer, we now study the evolution of a few metrics during training (the L2 error for predictions, the L2 error for gradients and the cosine similarity between models).
The metrics introduced above are defined as follows:
L2 error (predictions) measures the difference between the predictions of the linear transformer and the predictions of the transformer implementing a gradient descent step and it is defined as \(\left\|f\left(\mbtheta, \left[\{\mbx_i, \y_i\}_{i=0}^{C-1}, \mbx_\text{query}\right]\right) - f\left(\mbtheta_\text{GD}, \left[\{\mbx_i, \y_i\}_{i=0}^{C-1}, \mbx_\text{query}\right]\right) \right\|^2\);
L2 error (gradients w.r.t. inputs) measures the difference between the gradients of the linear transformer and the gradients of the transformer implementing a gradient descent step and it is defined as \(\left\|\nabla_{\mbx_\text{query}} f\left(\mbtheta, \left[\{\mbx_i, \y_i\}_{i=0}^{C-1}, \mbx_\text{query}\right]\right) - \nabla_{\mbx_\text{query}} f\left(\mbtheta_\text{GD}, \left[\{\mbx_i, \y_i\}_{i=0}^{C-1}, \mbx_\text{query}\right]\right) \right\|^2\);
Model cosine similarity (gradients w.r.t. inputs) measures the cosine similarity between the gradients of the linear transformer and the gradients of the transformer implementing a gradient descent step and it is defined as \(\cos\left(\nabla_{\mbx_\text{query}} f\left(\mbtheta, \left[\{\mbx_i, \y_i\}_{i=0}^{C-1}, \mbx_\text{query}\right]\right), \nabla_{\mbx_\text{query}} f\left(\mbtheta_\text{GD}, \left[\{\mbx_i, \y_i\}_{i=0}^{C-1}, \mbx_\text{query}\right]\right)\right)\).
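These metrics are straightforward to compute with automatic differentiation. Below is a minimal sketch, assuming a `predict(params, E)` function (our own name) that returns the scalar prediction for \(\mbx_\text{query}\), whose input occupies the first \(D\) entries of the last column of \(\mbE\).

```python
import jax
import jax.numpy as jnp

def metrics(params, params_gd, E, predict):
    """Compare a trained linear transformer with the GD-equivalent construction."""
    D = E.shape[0] - 1

    def f(p, x_query):
        E_q = E.at[:D, -1].set(x_query)          # swap in the query input
        return predict(p, E_q)

    x_query = E[:D, -1]
    pred_tf, pred_gd = f(params, x_query), f(params_gd, x_query)
    grad_tf = jax.grad(f, argnums=1)(params, x_query)      # gradient w.r.t. x_query
    grad_gd = jax.grad(f, argnums=1)(params_gd, x_query)

    cos = jnp.dot(grad_tf, grad_gd) / (jnp.linalg.norm(grad_tf) * jnp.linalg.norm(grad_gd))
    return {"l2_preds": (pred_tf - pred_gd) ** 2,
            "l2_grads": jnp.sum((grad_tf - grad_gd) ** 2),
            "cos_grads": cos}
```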
From this figure, we see that the predictions of the linear transformer converge to the predictions of the GD-transformer, and the gradients of the linear transformer converge to the gradients of the GD-transformer. Notably, this is true for all context sizes, though the convergence is faster for larger \(C\).
As a final visualization, we can also look at the evolution of the gradients of the linear transformer during training, as shown in the figure below. In this animation, we take six different regression tasks and we plot the gradients of the linear transformer during training and the exact gradients of the least-squares regression loss.
To reiterate, the loss landscape visualized is the least-squares regression loss, and each task is a different linear regression problem with a different loss landscape. Once more, this visualization shows that the linear transformer is not learning a single regression model, but is learning how to solve linear regression problems.
Next, we study the effect of the GD learning rate on the test loss of the GD-equivalent transformer. We believe this is an important point of discussion which was covered only briefly in the paper.
Indeed, this is the same procedure we used to find the optimal GD learning rate in our previous experiments. We now show what happens if we use a GD learning rate different from the one found with the line search. In the following experiment, we visualize this behavior by plotting the metrics described above for different values of the GD learning rate.
It turns out that a line search to find the best GD learning rate is not actually necessary: the optimal value admits an analytical solution.
The analytical solution is provided below, with its derivation reported in the collapsible section immediately following.
We are interested in finding the optimal learning rate for the GD-transformer, which, by construction (see the main result above), is equivalent to finding the optimal GD learning rate for the least-squares regression problem. Consequently, the analysis can be carried out on the least-squares regression problem \eqref{eq:linear-regression-loss}.
Recall the GD update of the least-squares regression in \eqref{eq:linear-regression-gd-gradient}, this time without the learning rate. That is,
\[\begin{equation} \label{eq:linear-regression-gd-gradient-no-lr} \Delta \mbw = \nabla_{\mbw} \cL_{\text{lin}}\left(\mbw, \{\mbx_i, \y_i\}_{i=0}^{C-1}\right) = \frac{1}{C} \sum_{i=0}^{C-1} \left(\mbw^\top\mbx_i - \y_i\right)\mbx_i. \end{equation}\]Now we consider the test loss of the least-squares regression defined as
\[\begin{equation} \cL_\mathrm{lin, te}(\{\mbw^{(n)}\}_{n=0}^{N-1}) = \frac{1}{N} \sum_{n=0}^{N-1} ((\mbx^{(n)}_\text{query})^\top \mbw^{(n)} - \y^{(n)}_\text{query})^2, \end{equation}\]where \(N\) is the number of queries, which equals the number of regression tasks in the in-context test set. Similar to \eqref{eq:linear-regression-loss-after-gd}, after one step of the GD update \eqref{eq:linear-regression-gd-gradient-no-lr}, the corresponding test loss becomes
\[\begin{align} &\quad \ \ \cL_\mathrm{lin, te}(\{\mbw^{(n)} - \eta \Delta \mbw^{(n)}\}_{n=0}^{N-1}) \nonumber \\ &= \frac{1}{N} \sum_{n=0}^{N-1} \left((\mbx^{(n)}_\text{query})^\top (\mbw^{(n)} - \eta \Delta \mbw^{(n)}) - \y^{(n)}_\text{query}\right)^2 \nonumber \\ &= \frac{1}{N} \sum_{n=0}^{N-1} \left((\mbx^{(n)}_\text{query})^\top \mbw^{(n)} - \y^{(n)}_\text{query} - \eta (\mbx^{(n)}_\text{query})^\top \Delta \mbw^{(n)} \right)^2 \nonumber \\ &= \frac{\eta^2}{N} \sum_{n=0}^{N-1} ((\mbx^{(n)}_\text{query})^\top \Delta \mbw^{(n)})^2 + \cL_\mathrm{lin, te}(\{\mbw^{(n)}\}_{n=0}^{N-1}) \nonumber \\ &\quad \ - \frac{2\eta}{N} \sum_{n=0}^{N-1} ((\mbx^{(n)}_\text{query})^\top \mbw^{(n)} - \y^{(n)}_\text{query})(\mbx^{(n)}_\text{query})^\top \Delta \mbw^{(n)}. \label{eq:loss_query_W1} \end{align}\]One can choose the optimum learning rate \(\eta^*\) such that \(\cL_\mathrm{lin, te}(\{\mbw^{(n)} - \eta \Delta \mbw^{(n)}\}_{n=0}^{N-1})\) achieves its minimum with respect to the learning rate \(\eta\). That is,
\[\begin{align} \eta^* \in \arg\min_{\eta > 0} \cL_\mathrm{lin, te}(\{\mbw^{(n)} - \eta \Delta \mbw^{(n)}\}_{n=0}^{N-1}). \end{align}\]To obtain \(\eta^*\), it suffices to solve
\(\begin{align} \nabla_\eta \cL_\mathrm{lin, te}(\{\mbw^{(n)} - \eta \Delta \mbw^{(n)}\}_{n=0}^{N-1}) = 0. \end{align}\) From \eqref{eq:loss_query_W1} and plugging in \(\Delta \mbw^{(n)}\) from \eqref{eq:linear-regression-gd-gradient-no-lr}, we obtain \(\begin{align} \eta^* &= \frac{\sum_{n=0}^{N-1} ((\mbx^{(n)}_\text{query})^\top \mbw^{(n)} - \y^{(n)}_\text{query})(\mbx^{(n)}_\text{query})^\top \Delta \mbw^{(n)} } {\sum_{n=0}^{N-1} ((\mbx^{(n)}_\text{query})^\top \Delta \mbw^{(n)})^2} \nonumber \\ &= C \frac{\sum_{n=0}^{N-1} ((\mbx^{(n)}_\text{query})^\top \mbw^{(n)} - \y^{(n)}_\text{query}) \sum_{i=0}^{C-1} ((\mbw^{(n)})^\top \mbx_i^{(n)} - \y_i^{(n)})(\mbx_i^{(n)})^\top \mbx^{(n)}_\text{query}} {\sum_{n=0}^{N-1} \left( \sum_{i=0}^{C-1} ((\mbw^{(n)})^\top \mbx_i^{(n)} - \y_i^{(n)})(\mbx_i^{(n)})^\top \mbx^{(n)}_\text{query} \right)^2}. \end{align}\) Finally, for the initialization \(\mbw^{(n)} = 0\) for \(n = 0, \ldots, N-1\), the optimal learning rate simplifies to \(\begin{align} \eta^* = C \frac{\sum_{n=0}^{N-1} \y^{(n)}_\text{query} \left(\sum_{i=0}^{C-1}\left( \y^{(n)}_i{\left(\mbx^{(n)}_i\right)}^\top \mbx_\text{query}^{(n)}\right)\right) }{\sum_{n=0}^{N-1} \left(\sum_{i=0}^{C-1}\left(\y^{(n)}_i {\left(\mbx^{(n)}_i\right)}^\top \mbx_\text{query}^{(n)}\right)\right)^2}. \end{align}\)
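For completeness, here is how the closed-form \(\eta^*\) (with the \(\mbw^{(n)}=0\) initialization) could be computed from a batch of test tasks in a few lines of JAX; the function and variable names are ours.

```python
import jax.numpy as jnp

def optimal_gd_lr(x, y, x_query, y_query):
    """Closed-form optimal GD learning rate for w0 = 0.

    x: (N, C, D) context inputs,  y: (N, C) context targets,
    x_query: (N, D) query inputs, y_query: (N,) query targets."""
    C = x.shape[1]
    # s_n = sum_i y_i^(n) (x_i^(n))^T x_query^(n)
    s = jnp.einsum('nci,nc,ni->n', x, y, x_query)
    return C * jnp.sum(y_query * s) / jnp.sum(s ** 2)
```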
This derivation of the optimal GD learning rate \(\eta^*\) agrees well with the line search (up to the numerical precision of the line search itself). While this is expected, let's take a moment to understand why.
The analytical solution is obtained starting from the linear regression loss, while the line search procedure uses the loss \(\cL(\mbtheta_\text{GD})\) defined in Equation \eqref{eq:pre-train-loss-expectation}. However, the two losses are equivalent by construction, hence the two procedures are equivalent.
Because the construction of the GD-transformer is not unique, the effect of the GD learning rate is not immediately visible when we compare it with the trained linear transformer. Recall that, due to its parametrization, the linear transformer does not have an explicit \(\eta\) parameter, since it can be absorbed into any of the weight matrices of the linear self-attention layer. Yet, the linear transformer converges to the exact same loss as the GD-transformer with the optimal GD learning rate \(\eta^*\). This is expected because, fundamentally, the loss function used for the line search and the one used for the analytical solution are both equivalent to the loss in Equation \eqref{eq:pre-train-loss-expectation} used during transformer training.
Said differently, what we did in two steps for the GD-transformer (first build the \(\mbW^K, \mbW^Q, \mbW^V\) matrices, then find the optimal GD learning rate) is done implicitly during the training of the linear transformer.
The following table summarizes the three different procedures we have discussed so far.
| | Loss function | GD learning rate |
|---|---|---|
| Least-squares regression | \(\cL_\text{lin}(\mbw-\Delta \mbw)\) | Explicit \(\eta^*\) by analytical solution |
| GD-transformer | \(\cL(\mbtheta_\text{GD})\) | Explicit \(\eta^*\) by line search |
| Linear transformer | \(\cL(\mbtheta)\) | Implicit \(\eta^*\) by training \(\mbtheta\) |
Finally, one comment on the computational complexity of the two procedures. It doesn’t come as a surprise that the analytical solution is faster to compute than the line search: the line search requires on average 10 seconds to find the optimal GD learning rate, while the analytical solution requires only 10 milliseconds (both with JAX’s JIT compilation turned on, run on the same GPU).
It is only natural to ask whether the same behavior is observed for a linear transformer with multiple layers. In particular, if we take a trained single-layer linear transformer (which we now know implements a gradient descent step) and apply the same layer update recursively multiple times, will we observe the same behavior?
As we now show in the following experiment, the answer is no. In fact, the test loss for both the linear transformer and the transformer implementing a gradient descent step diverges as we increase the number of layers.
To stabilize this behavior, we use a dampening factor \(\lambda\), which is a scalar in \([0, 1]\), and we update the linear transformer as follows:
\[\begin{equation} \label{eq:linear-transformer-update} \mbE^{(l+1)} = \mbE^{(l)} + \lambda \mbW^P \mbV\left(\mbK^\top \mbQ \right), \end{equation}\]where \(\mbE^{(l)}\) is the embedding matrix at layer \(l\), and \(\mbW^P, \mbV, \mbK, \mbQ\) are the projection, value, key and query matrices as defined before. Effectively, this is equivalent to applying a gradient descent step with scaled learning rate.
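As a sketch, the recursive application with dampening amounts to re-using the same layer weights at every step (the default value of the dampening factor below is the one suggested in the original paper).

```python
def recurrent_linear_transformer(params, E, n_layers, damp=0.75):
    """Apply the same linear self-attention weights recursively, with dampening,
    following Equation (linear-transformer-update)."""
    for _ in range(n_layers):
        K, Q, V = params['WK'] @ E, params['WQ'] @ E, params['WV'] @ E
        E = E + damp * params['WP'] @ V @ (K.T @ Q)
    return E
```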
Note that in the original paper, the authors suggest that a dampening factor of \(\lambda=0.75\) is generally sufficient to obtain the same behavior as a single layer linear transformer. As we can see from the figure above, in our investigations we do not find this to be the case. In our experiments, we see that we need at least \(\lambda=0.70\) to obtain the same behavior as a single layer linear transformer, which suggests that the effect of the dampening factor can vary.
Transformers are not the only architecture capable of sequence-to-sequence modeling.
Indeed, from a modeling perspective, nothing prevents us from using an LSTM to implement in-context learning for regression tasks. In fact, we can use the same experimental setup as before, but replacing the transformer with an LSTM. The main architectural difference between an LSTM and a transformer is that LSTM layers are causal by design, i.e. they can only attend to previous tokens in the sequence, while transformers can attend to any token in the sequence. While this is a desirable property for tasks where order matters, such as language modeling, it is not obviously an advantage in our regression setup, where the order of the context points is irrelevant.
In this first experiment, we analyze the ability of the uni-directional and the bi-directional LSTM to learn linear functions in-context. Note that because of the intrinsically non-linear nature of LSTM layers, we cannot manually construct an LSTM that implements a gradient descent step, as we did for the transformer. Nonetheless, we can still compare the LSTMs with the GD-equivalent transformer (which we now know implements a gradient descent step on the least-squares regression loss).
In this figure we can see that a single-layer LSTM is not sufficient to learn linear functions in-context. For the uni-directional LSTM, the test loss is always higher than that of the transformer implementing a gradient descent step, even as we increase the number of layers. On the contrary, for the bi-directional LSTM, the test loss approaches that of the GD-equivalent transformer as the number of layers increases.
The poor performance of the uni-directional LSTM is not surprising. Additional evidence is provided in the figure below, where, as we did for the transformer, we plot the L2 error (predictions), the L2 error (gradients w.r.t. inputs) and the model cosine similarity (gradients w.r.t. inputs) comparing the LSTM with the GD-equivalent transformer.
Regardless of the number of layers, we see that the uni-directional LSTM is not implementing a gradient descent step, as the L2 error (predictions) and the L2 error (gradients w.r.t. inputs) do not converge to 0, and the model cosine similarity (gradients w.r.t. inputs) remains well below 1. The picture changes for the bi-directional LSTM, as we can see in the figure below.
While for a single layer we can comfortably say that the bi-directional LSTM, too, is not equivalent to a GD step, for 2 or more layers we cannot reject the hypothesis that it is (use the slider to change the number of layers in Figures 14-16). Note that if we compare this result with Figure 10, the behavior is not exactly the same (e.g. the cosine similarity is slightly below 1), but it is still remarkably similar. This is not a conclusive result, but it is interesting to see that the bi-directional LSTM can learn linear functions in-context similarly to a transformer implementing a gradient descent step.
In this blog post, we have presented a series of experiments to understand the mechanistic behavior of transformers and self-attention layers through the lens of optimization theory. In particular, we have analyzed the results of the paper Transformers Learn In-Context by Gradient Descent
The results presented in this blog post, while confirming the main findings of the original paper, also raise a number of questions and suggest possible future research directions.
To reiterate, what we have done so far is to try to understand the behavior of transformers and self-attention layers through the lens of optimization theory. This is the common approach in the literature, including very recent additions
On the other hand, nothing prevents us from working in the opposite direction, i.e. starting from specific learning algorithms and trying to design neural networks that implement them. From an alignment perspective, for example, this is desirable because it allows us to start by designing objective functions and learning algorithms that are more interpretable and more aligned with our objectives, rather than starting from a black-box neural network and trying to understand its behavior. In this quest, the developing theory of mesa-optimization may prove a useful guide.
Finally, we want to highlight that the main results shown in this blog post are consequences of the simplifying hypotheses and the experimental setup we have considered (linear functions, least-squares regression loss, linear self-attention layers). In an equally recent paper