One of the most popular methods for continual learning with deep neural networks is Elastic Weight Consolidation (EWC), which involves computing the Fisher Information. However, the exact way in which the Fisher Information is computed is rarely described, and multiple different implementations of it can be found online. This blog post discusses and empirically compares several often-used implementations, which highlights that many currently reported results for EWC could likely be improved by changing the way the Fisher Information is computed.
Continual learning is a rapidly growing subfield of deep learning devoted to enabling neural networks to incrementally learn new tasks, domains or classes while not forgetting previously learned ones. Such continual learning is crucial for addressing real-world problems where data are constantly changing, such as in healthcare, autonomous driving or robotics. Unfortunately, continual learning is challenging for deep neural networks, mainly due to their tendency to forget previously acquired skills when learning something new.
Elastic Weight Consolidation (EWC)
The Fisher Information matrix is also frequently used in the optimization literature. There, several years ago, Kunstner and colleagues cautioned against using the “empirical” Fisher Information as a substitute for the true Fisher Information; this distinction will become relevant later in this post.
Before diving into EWC and the computation of the Fisher Information, let me introduce the continual learning problem by means of a simple example. Say we have a deep neural network model \(f_{\boldsymbol{\theta}}\), parameterized by weight vector \(\boldsymbol{\theta}\). This model has already been trained on a first task (or a first set of tasks; this example works recursively), by optimizing a loss function \(\ell_{\text{old}}(\boldsymbol{\theta})\) on training data \(D_{\text{old}}\sim\mathcal{D}_{\text{old}}\). This resulted in weights \(\hat{\boldsymbol{\theta}}_{\text{old}}\). We then wish to continue training this model on a new task, by optimizing a loss function \(\ell_{\text{new}}(\boldsymbol{\theta})\) on training data \(D_{\text{new}}\sim\mathcal{D}_{\text{new}}\), in such a way that the model maintains, or possibly improves, its performance on the previously learned task(s). Unfortunately, as has been thoroughly described in the continual learning literature, if the model simply continues to be trained on the new data in the standard way (i.e., optimizing \(\ell_{\text{new}}(\boldsymbol{\theta})\) with stochastic gradient descent), the typical result is catastrophic forgetting
In this blog post, similar to most of the deep learning work on continual learning, the focus is on supervised learning. Each data point thus consists of an input \(\boldsymbol{x}\) and a corresponding output \(y\), and our deep neural network models the conditional distribution \(p_{\boldsymbol{\theta}}(y|\boldsymbol{x})\).
Now we are ready to take a detailed look at EWC. We start by formally defining this method. When training on a new task, to prevent catastrophic forgetting, rather than optimizing only the loss on the new task \(\ell_{\text{new}}(\boldsymbol{\theta})\), EWC adds an extra term to the loss that involves the Fisher Information:
\[\ell_{\text{EWC}}(\boldsymbol{\theta})= \ell_{\text{new}}(\boldsymbol{\theta})+\frac{\lambda}{2}\sum_{i=1}^{N_{\text{params}}} F_{\text{old}}^{i,i}(\theta^{i}-\hat{\theta}_{\text{old}}^{i})^2\]In this expression, \(N_{\text{params}}\) is the number of parameters in the model, \(\theta^{i}\) is the value of parameter \(i\) (i.e., the \(i^{\text{th}}\) element of weight vector \(\boldsymbol{\theta}\)), \(F_{\text{old}}^{i,i}\) is the \(i^{\text{th}}\) diagonal element of the model’s Fisher Information matrix on the old data, and \(\lambda\) is a hyperparameter that sets the relative importance of the new task compared to the old one(s).
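As a concrete illustration, below is a minimal PyTorch sketch of how this loss could be implemented. The function name `ewc_loss` and the dictionaries `fisher_diag` and `old_params`, which are assumed to map parameter names to tensors holding the diagonal Fisher elements \(F_{\text{old}}^{i,i}\) and the old parameter values \(\hat{\theta}_{\text{old}}^{i}\), are hypothetical conventions of mine, not part of any official EWC implementation:

```python
import torch

def ewc_loss(model, loss_new, fisher_diag, old_params, lam):
    """EWC loss: the new-task loss plus a quadratic penalty that anchors each
    parameter to its old value, weighted by its diagonal Fisher element."""
    penalty = 0.0
    for name, param in model.named_parameters():
        penalty = penalty + (fisher_diag[name] * (param - old_params[name]) ** 2).sum()
    return loss_new + (lam / 2.0) * penalty
```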
EWC can be motivated from two perspectives, each of which I discuss next.
Loosely inspired by neuroscience theories of how synapses in the brain that are critical for previously learned skills are protected from overwriting during subsequent learning, the first motivation for EWC is that of parameter regularization: the diagonal elements of the Fisher Information serve as estimates of how important each parameter is for the old task(s), and the quadratic penalty term discourages important parameters from moving far away from their old values when training on the new task.
A second motivation for EWC comes from a Bayesian perspective, because EWC can also be interpreted as performing approximate Bayesian inference on the parameters of the neural network. For this we need to take a probabilistic perspective, meaning that we view the network parameters \(\boldsymbol{\theta}\) as a random variable over which we want to learn a distribution. Then, when learning a new task, the idea behind EWC is to use the posterior distribution \(p(\boldsymbol{\theta}|D_{\text{old}})\) that was found after training on the old task(s) as the prior distribution when training on the new task. To make this procedure tractable, the Laplace approximation is used, meaning that the distribution \(p(\boldsymbol{\theta}|D_{\text{old}})\) is approximated as a Gaussian centered around \(\hat{\boldsymbol{\theta}}_{\text{old}}\) and with the Fisher Information \(F_{\text{old}}\) as precision matrix. To keep the computational costs manageable, EWC additionally sets all off-diagonal elements of \(F_{\text{old}}\) to zero, so that only the diagonal remains.
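Written out, this Laplace approximation with the diagonal Fisher is:
\[p(\boldsymbol{\theta}|D_{\text{old}}) \approx \mathcal{N}\left(\boldsymbol{\theta};\, \hat{\boldsymbol{\theta}}_{\text{old}},\, \big(\text{diag}(F_{\text{old}})\big)^{-1}\right)\]Taking the negative logarithm of this Gaussian and using it as the prior's contribution to the loss recovers, up to an additive constant and the weighting factor \(\lambda\), the quadratic penalty term in the EWC loss defined above.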
EWC thus involves computing the diagonal elements of the network’s Fisher Information on the old data. Following the definitions and notation in Martens, these diagonal elements are given by:
\[F_{\text{old}}^{i,i} = \mathbb{E}_{\boldsymbol{x}\sim\mathcal{D}_{\text{old}}} \left[ \mathbb{E}_{y\sim p_{\hat{\boldsymbol{\theta}}_{\text{old}}}\left(y|\boldsymbol{x}\right)} \left[ \left( \left. \frac{\delta\log p_{\boldsymbol{\theta}}\left(y|\boldsymbol{x}\right)}{\delta\theta^i} \right\rvert_{\boldsymbol{\theta}=\hat{\boldsymbol{\theta}}_{\text{old}}} \right)^2 \right] \right]\]
In this definition, there are two expectations: (1) an outer expectation over \(\mathcal{D}_{\text{old}}\), which is the (theoretical) input distribution of the old data; and (2) an inner expectation over \(p_{\hat{\boldsymbol{\theta}}_{\text{old}}}(y|\boldsymbol{x})\), which is the conditional distribution of \(y\) given \(\boldsymbol{x}\) defined by the neural network after training on the old data. The different ways of computing the Fisher Information that can be found in the continual learning literature differ in how these two expectations are computed or approximated.
If computational costs are not an issue, the outer expectation in the definition of \(F_{\text{old}}^{i,i}\) can be estimated by averaging over all available training data \(D_{\text{old}}\), while — in the case of a classification problem — the inner expectation can be calculated for each training sample exactly:
\[F_{\text{old, EXACT}}^{i,i} = \frac{1}{|D_{\text{old}}|} \sum_{\boldsymbol{x}\in D_{\text{old}}} \left( \sum_{y=1}^{N_{\text{classes}}} p_{\hat{\boldsymbol{\theta}}_{\text{old}}}\left(y|\boldsymbol{x}\right) \left( \left. \frac{\delta\log p_{\boldsymbol{\theta}}\left(y|\boldsymbol{x}\right)}{\delta\theta^i} \right\rvert_{\boldsymbol{\theta}=\hat{\boldsymbol{\theta}}_{\text{old}}} \right)^2 \right)\]I refer to this option as EXACT, because for each sample in \(D_{\text{old}}\), the diagonal elements of the Fisher Information are computed exactly. I am not aware of many implementations of EWC that use this way of computing the Fisher Information, but one example can be found in ref
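To illustrate, here is a minimal PyTorch sketch of the EXACT option, assuming a classification model that maps an input batch to logits and a dataset that yields `(x, y)` pairs; the function name `fisher_diag_exact` is my own:

```python
import torch
import torch.nn.functional as F

def fisher_diag_exact(model, dataset):
    """For each sample, sum the squared gradients of the log-probability over
    all classes, weighted by the model's own predicted probabilities."""
    fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters()}
    model.eval()
    for x, _ in dataset:  # the labels are not needed for the exact Fisher
        log_probs = F.log_softmax(model(x.unsqueeze(0)), dim=1)
        probs = log_probs.exp().detach()
        for c in range(log_probs.shape[1]):
            model.zero_grad()
            log_probs[0, c].backward(retain_graph=True)
            for n, p in model.named_parameters():
                fisher[n] += probs[0, c] * p.grad.detach() ** 2
    return {n: f / len(dataset) for n, f in fisher.items()}
```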
One way to reduce the costs of computing \(F^{i,i}_{\text{old}}\) is by estimating the outer expectation using only a subset of the old training data:
\[F_{\text{old, EXACT}(n)}^{i,i} = \frac{1}{n} \sum_{\boldsymbol{x}\in S_{D_{\text{old}}}^{(n)}} \left( \sum_{y=1}^{N_{\text{classes}}} p_{\hat{\boldsymbol{\theta}}_{\text{old}}}\left(y|\boldsymbol{x}\right) \left( \left. \frac{\delta\log p_{\boldsymbol{\theta}}\left(y|\boldsymbol{x}\right)}{\delta\theta^i} \right\rvert_{\boldsymbol{\theta}=\hat{\boldsymbol{\theta}}_{\text{old}}} \right)^2 \right)\]whereby \(S_{D_{\text{old}}}^{(n)}\) is a set of \(n\) random samples from \(D_{\text{old}}\). Although this seems a natural way to reduce the computational costs of computing the Fisher Information, I am aware of only one study
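Reusing the hypothetical `fisher_diag_exact` sketched above, the EXACT(\(n\)) variant only changes which data are passed in, for example:

```python
import random

# Estimate the Fisher on n=500 random samples from the old training data,
# reusing the fisher_diag_exact() sketch from above.
subset = random.sample(list(dataset), 500)
fisher = fisher_diag_exact(model, subset)
```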
Another way to make the computation of \(F^{i,i}_{\text{old}}\) less costly is by computing the squared gradient not for all possible classes, but only for a single class per training sample. This means that the inner expectation in the definition of \(F^{i,i}_{\text{old}}\) is no longer computed exactly. To maintain an unbiased estimate of the inner expectation, Monte Carlo sampling can be used. That is, for each given training sample \(\boldsymbol{x}\), the class for which to compute the squared gradient can be selected by sampling from \(p_{\hat{\boldsymbol{\theta}}_{\text{old}}}\left(.|\boldsymbol{x}\right)\). This gives:
\[F_{\text{old, SAMPLE}}^{i,i} = \frac{1}{|D_{\text{old}}|} \sum_{\boldsymbol{x}\in D_{\text{old}}} \left( \left. \frac{\delta\log p_{\boldsymbol{\theta}}\left(c_{\boldsymbol{x}}|\boldsymbol{x}\right)}{\delta\theta^i} \right\rvert_{\boldsymbol{\theta}=\hat{\boldsymbol{\theta}}_{\text{old}}} \right)^2\]whereby, independently for each \(\boldsymbol{x}\), \(c_{\boldsymbol{x}}\) is randomly sampled from \(p_{\hat{\boldsymbol{\theta}}_{\text{old}}}(.|\boldsymbol{x})\). I refer to this option as SAMPLE. This way of unbiasedly estimating the Fisher Information has been used in the implementation of EWC in refs
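A minimal PyTorch sketch of the SAMPLE option, under the same assumptions as the EXACT sketch above (`fisher_diag_sample` is a hypothetical name):

```python
import torch
import torch.nn.functional as F

def fisher_diag_sample(model, dataset):
    """Monte Carlo estimate: for each sample, draw a single class from the
    model's predictive distribution and square the gradient of its
    log-probability."""
    fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters()}
    model.eval()
    for x, _ in dataset:  # the ground-truth labels are not used
        log_probs = F.log_softmax(model(x.unsqueeze(0)), dim=1)
        c = torch.multinomial(log_probs.exp().detach(), num_samples=1).item()
        model.zero_grad()
        log_probs[0, c].backward()
        for n, p in model.named_parameters():
            fisher[n] += p.grad.detach() ** 2
    return {n: f / len(dataset) for n, f in fisher.items()}
```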
Another option is to compute the squared gradient only for each sample’s ground-truth class:
\[F_{\text{old, EMPIRICAL}}^{i,i} = \frac{1}{|D_{\text{old}}|} \sum_{\left(\boldsymbol{x},y\right)\in D_{\text{old}}} \left( \left. \frac{\delta\log p_{\boldsymbol{\theta}}\left(y|\boldsymbol{x}\right)}{\delta\theta^i} \right\rvert_{\boldsymbol{\theta}=\hat{\boldsymbol{\theta}}_{\text{old}}} \right)^2\]Computed this way, \(F_{\text{old}}\) corresponds to the “empirical” Fisher Information matrix
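A corresponding sketch of the EMPIRICAL option; the only change relative to SAMPLE is that each sample's ground-truth class is used instead of a sampled class (`fisher_diag_empirical` is again a hypothetical name):

```python
import torch
import torch.nn.functional as F

def fisher_diag_empirical(model, dataset):
    """Empirical Fisher: square the gradient of the log-probability of each
    sample's ground-truth class."""
    fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters()}
    model.eval()
    for x, y in dataset:
        log_probs = F.log_softmax(model(x.unsqueeze(0)), dim=1)
        model.zero_grad()
        log_probs[0, y].backward()
        for n, p in model.named_parameters():
            fisher[n] += p.grad.detach() ** 2
    return {n: f / len(dataset) for n, f in fisher.items()}
```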
The last option that we consider has probably come about thanks to a feature of PyTorch. Note that all of the above ways of computing \(F^{i,i}_{\text{old}}\) require access to the gradients of individual data points, as the gradients need to be squared before being summed. However, standard batch-wise operations in PyTorch only give access to the aggregated gradients of a mini-batch, not to the individual, unaggregated gradients. The above ways of computing \(F^{i,i}_{\text{old}}\) can therefore only be implemented straightforwardly in PyTorch with mini-batches of size one. Perhaps in an attempt to gain efficiency, several implementations of EWC can be found on GitHub that instead compute \(F^{i,i}_{\text{old}}\) by squaring the aggregated gradients of mini-batches of size larger than one. Indeed, popular continual learning libraries such as Avalanche compute the Fisher Information in this way:
\[F_{\text{old, BATCHED}(b)}^{i,i} = \frac{1}{|D_{\text{old}}^{(b)}|} \sum_{B\in D_{\text{old}}^{(b)}} \left( \left. \frac{\delta}{\delta\theta^i} \left( \frac{1}{b}\sum_{(\boldsymbol{x},y)\in B} \log p_{\boldsymbol{\theta}}\left(y|\boldsymbol{x}\right) \right) \right\rvert_{\boldsymbol{\theta}=\hat{\boldsymbol{\theta}}_{\text{old}}} \right)^2\]whereby \(D_{\text{old}}^{(b)}\) is a batched version of the old training data \(D_{\text{old}}\), so that the elements of \(D_{\text{old}}^{(b)}\) are mini-batches with \(b\) training samples each. (And \(|D_{\text{old}}^{(b)}|\) is the number of mini-batches, not the number of training samples.) Below, we will explore this option using \(b=128\), referring to it as BATCHED (\(b\)=128).
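A minimal PyTorch sketch of this batched computation, assuming a `DataLoader` that yields mini-batches of size \(b\) (`fisher_diag_batched` is a hypothetical name); note that the gradient is aggregated over the whole mini-batch before being squared:

```python
import torch
import torch.nn.functional as F

def fisher_diag_batched(model, data_loader):
    """Batched approximation: per-sample gradients are never formed; instead
    the gradient of the mean log-likelihood of each mini-batch is squared."""
    fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters()}
    model.eval()
    for x, y in data_loader:  # mini-batches of b samples each
        log_probs = F.log_softmax(model(x), dim=1)
        model.zero_grad()
        (-F.nll_loss(log_probs, y)).backward()  # mean log-likelihood over the batch
        for n, p in model.named_parameters():
            fisher[n] += p.grad.detach() ** 2
    return {n: f / len(data_loader) for n, f in fisher.items()}
```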
Now, let us empirically compare the performance of EWC with these various ways of computing the Fisher Information. To do so, I use two relatively simple, often used continual learning benchmarks: Split MNIST and Split CIFAR-10. For these benchmarks, the original MNIST or CIFAR-10 dataset is split up into five tasks with two classes per task. Both benchmarks are performed according to the task-incremental learning scenario, using a separate softmax output layer for each task. For Split MNIST, following refs
For the experiments on Split MNIST, the results are shown below in a graph and a table:
| Method | Accuracy (in %) | Train time (in s) |
|---|---|---|
| EXACT | 99.12 ($\pm$ 0.16) | 121 ($\pm$ 1) |
| EXACT ($n$=500) | 98.93 ($\pm$ 0.15) | 58 ($\pm$ 0) |
| SAMPLE | 99.12 ($\pm$ 0.12) | 101 ($\pm$ 1) |
| EMPIRICAL | 99.12 ($\pm$ 0.12) | 100 ($\pm$ 1) |
| BATCHED ($b$=128) | 99.11 ($\pm$ 0.16) | 58 ($\pm$ 1) |
| None | 85.41 ($\pm$ 0.88) | 53 ($\pm$ 0) |
From the table we can see that for Split MNIST, when looking only at the performance with the best-performing hyperparameter, there are no substantial differences between the various ways of computing the Fisher. However, there are large differences in terms of the range of hyperparameter values with which EWC performs well. For example, with the BATCHED option, EWC requires a value of \(\lambda\) that is orders of magnitude larger than the best value for the EXACT option. This suggests that there might be important differences between these ways of computing the Fisher, but that the task-incremental version of Split MNIST is perhaps not difficult enough to elicit significant differences in best performance between them.
Therefore, let us look at the results on the more difficult Split CIFAR-10 benchmark. Again, the results are presented in the form of a graph and a table:
| Method | Accuracy (in %) | Train time (in s) |
|---|---|---|
| EXACT | 84.91 ($\pm$ 0.65) | 972 ($\pm$ 5) |
| EXACT ($n$=500) | 84.57 ($\pm$ 0.72) | 668 ($\pm$ 1) |
| SAMPLE | 83.77 ($\pm$ 0.72) | 817 ($\pm$ 2) |
| EMPIRICAL | 83.28 ($\pm$ 0.77) | 817 ($\pm$ 2) |
| BATCHED ($b$=128) | 83.38 ($\pm$ 0.66) | 667 ($\pm$ 1) |
| None | 69.53 ($\pm$ 1.34) | 627 ($\pm$ 3) |
Indeed, on this benchmark there are significant differences between the options also in terms of their best performance. The performance of EWC is substantially better when the Fisher Information is computed exactly, even when this is done only for a subset of the old training data, than when it is estimated or approximated in some way. We can further see that the SAMPLE option, which uses an unbiased estimate of the true Fisher, appears to perform somewhat better than using the empirical Fisher, although the difference is small and not conclusive. Interestingly, also on this more difficult benchmark, the batched approximation of the empirical Fisher still results in a best performance similar to that of the regular empirical Fisher, although these two options do differ in terms of their optimal hyperparameter range.
I finish this blog post by concluding that the way in which the Fisher Information is computed can have a substantial impact on the performance of EWC. This is an important realization for the continual learning research community. Going forward, based on my findings, I have three recommendations for researchers in this field. Firstly, whenever using EWC, or another method that relies on the Fisher Information, make sure to describe the details of how the Fisher Information is computed. Secondly, do not simply “use the best performing hyperparameter(s) from another paper”, especially if you cannot guarantee that the details of your implementation are the same as in that paper. And thirdly, when using the Fisher Information matrix, it is preferable to compute it exactly rather than to approximate it. If computational resources are scarce, it may be better to reduce the number of training samples used to compute the Fisher than to cut corners in other ways.
This work has been supported by a senior postdoctoral fellowship from the Research Foundation — Flanders (FWO) under grant number 1266823N.