Matthew Riemer et al. Learning to Learn Without Forgetting by Maximizing Transfer and Minimizing Interference. ICLR, 2019.
- Recall catastrophic forgetting: a neural network trained sequentially on multiple tasks forgets earlier tasks with each new task; apparently not a problem in Bayesian networks
- Why? overwriting weights with updates, …
- How to avoid? limit weight sharing, balance network stability vs plasticity (“recall of old tasks” versus “rapid learning of new ones”), …
- The loss function: \[\sum_{i, j} \left[ L(x_i, y_i) + L(x_j, y_j) - \alpha {\partial L(x_i, y_i) \over \partial \theta} \cdot {\partial L(x_j, y_j) \over \partial \theta} \right]\]
- The regularizing term measures transfer or interference between updates. The gradient with respect to the parameters \(\theta\) drives the backprop update to those parameters: aligned gradients mean the updates agree and learning one example helps the other; anti-aligned gradients mean the updates cancel and neither example is learned; anything in between counts as transfer when the dot product is positive and as interference when it is negative.
- Maximizing weight sharing maximizes transfer; minimizing weight sharing minimizes the chance of interference.
-
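A minimal sketch of the gradient dot product as a transfer/interference signal. The linear model, the squared-error loss, and the names `grad_squared_error` and `classify_pair` are my own illustrative assumptions, not anything from the paper.

```python
def grad_squared_error(theta, x, y):
    """Gradient of L = 0.5 * (theta . x - y)^2 with respect to theta."""
    residual = sum(t * xi for t, xi in zip(theta, x)) - y
    return [residual * xi for xi in x]


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def classify_pair(theta, example_i, example_j):
    """Label a pair of examples by the sign of their gradient dot product."""
    g_i = grad_squared_error(theta, *example_i)
    g_j = grad_squared_error(theta, *example_j)
    overlap = dot(g_i, g_j)
    if overlap > 0:
        return "transfer", overlap      # updates point the same way
    if overlap < 0:
        return "interference", overlap  # updates work against each other
    return "neutral", overlap           # orthogonal gradients


theta = [0.0, 0.0]
# Two examples consistent with the same target function y = x0.
aligned = classify_pair(theta, ([1.0, 0.0], 1.0), ([2.0, 0.0], 2.0))
# Same input, contradictory labels.
opposed = classify_pair(theta, ([1.0, 0.0], 1.0), ([1.0, 0.0], -1.0))
print(aligned, opposed)
```

The first pair agrees on the direction of the update, the second pair pulls the same weight in opposite directions.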
Work leading up to this paper, both offline algorithms over dataset D:
- MAML -> FOMAML (Finn & Levine, 2017)
- Reptile (Nichol & Schulman, 2018)
-
Contributions: a new online algorithm, MER (Meta-Experience Replay); see Algorithm 1 in the paper, with variants in Algorithms 6 and 7:
- added an inner loop within Reptile batches for an inner meta-learning update
- keeps a memory/reservoir of examples M to approximate the full dataset D with new examples added probabilistically to replace old ones (see algorithm 3 in the paper)
- prioritizes learning of the current example, especially because it may not be saved to the reservoir
-
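The memory update can be sketched with standard reservoir sampling. This is my reading of what "added probabilistically to replace old ones" means; the function name `reservoir_update` and its arguments are hypothetical, and the paper's algorithm 3 may differ in details.

```python
import random


def reservoir_update(memory, capacity, example, n_seen, rng=random):
    """Standard reservoir sampling step.

    n_seen is the number of examples seen so far, counting this one.
    After the call, every example seen so far is in memory with equal
    probability capacity / n_seen (once the memory is full).
    """
    if len(memory) < capacity:
        memory.append(example)          # still room: always store
    else:
        slot = rng.randrange(n_seen)    # uniform integer in [0, n_seen)
        if slot < capacity:
            memory[slot] = example      # overwrite a stored example
        # otherwise the new example is dropped and never stored
    return memory


rng = random.Random(0)
memory = []
for n_seen, example in enumerate(range(100), start=1):
    reservoir_update(memory, 10, example, n_seen, rng)
print(sorted(memory))
```

Note the last branch: a new example can be discarded outright, which is why the current example needs to be learned right away.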
First, the Reptile algorithm:
- for each epoch of training, \(t\), record the current params, \(\theta^A_0 = \theta_{t-1}\) and sample \(s\) batches of size \(k\)
- perform a normal epoch of training over the \(s\) batches with learning rate \(\alpha\) toward final params \(\theta^A_s\)
- update the network weights for this epoch by only a fraction \(\gamma\) of the learned parameter change: \[\theta_t = \theta^A_0 + \gamma (\theta^A_s - \theta^A_0)\]
- this meta-learning update enacts the effective loss \[2\sum_{i=1}^s \left[ L(B_i) - \sum_{j=1}^{i-1} \alpha {\partial L(B_i) \over \partial \theta} \cdot {\partial L(B_j) \over \partial \theta} \right]\]
-
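A rough sketch of the Reptile steps above on a toy problem. The linear model, the squared-error loss, the synthetic data, and the names `sgd_step` and `reptile_step` are assumptions made for illustration; only the interpolation \(\theta_t = \theta^A_0 + \gamma (\theta^A_s - \theta^A_0)\) comes from the notes.

```python
import random


def sgd_step(theta, batch, lr):
    """One SGD step on the mean squared error of a linear model theta . x."""
    grad = [0.0] * len(theta)
    for x, y in batch:
        residual = sum(t * xi for t, xi in zip(theta, x)) - y
        for d, xi in enumerate(x):
            grad[d] += residual * xi / len(batch)
    return [t - lr * g for t, g in zip(theta, grad)]


def reptile_step(theta, data, s, k, alpha, gamma, rng):
    """Train on s batches of size k, then move only a fraction gamma of the way."""
    theta_start = list(theta)            # theta^A_0
    theta_inner = list(theta)
    for _ in range(s):
        batch = rng.sample(data, k)      # B_i
        theta_inner = sgd_step(theta_inner, batch, alpha)
    # theta_t = theta^A_0 + gamma * (theta^A_s - theta^A_0)
    return [t0 + gamma * (ts - t0) for t0, ts in zip(theta_start, theta_inner)]


# Toy data from y = 2 * x0 - x1 (hypothetical, for illustration only).
rng = random.Random(0)
data = []
for _ in range(50):
    x = [rng.uniform(-1, 1), rng.uniform(-1, 1)]
    data.append((x, 2.0 * x[0] - 1.0 * x[1]))

theta = [0.0, 0.0]
for _ in range(200):
    theta = reptile_step(theta, data, s=5, k=4, alpha=0.1, gamma=0.5, rng=rng)
print(theta)
```

With \(\gamma = 1\) this reduces to plain SGD over the \(s\) batches; the regularizing effect comes from taking several inner steps before interpolating.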
MER adds a second meta-learning update within each of the \(s\) batches, which are now sampled from the reservoir M and each include the current example; finally, the reservoir is (probabilistically) updated
- for each epoch of training, \(t\), record the current params, \(\theta^A_0 = \theta_{t-1}\), and sample \(s\) batches of size \(k\), including the current example \((x_t, y_t)\) in each
- for each batch \(i\), record the current params, \(\theta^A_{i, 0} = \theta^A_{i-1}\)
- for each example \(j\) in the batch, perform a backprop update with learning rate \(\alpha\) to params \(\theta^A_{i, j}\)
- after every example in the batch has been learned one at a time, apply the within-batch meta-learning update \[\theta^A_i = \theta^A_{i, 0} + \beta (\theta^A_{i, k} - \theta^A_{i, 0})\]
- the effective loss is \[2\sum_{i=1}^s \sum_{j=1}^k \left[ L(x_{ij}, y_{ij}) - \sum_{q=1}^{i-1}\sum_{r=1}^{j-1} \alpha {\partial L(x_{ij}, y_{ij}) \over \partial \theta} \cdot {\partial L(x_{qr}, y_{qr}) \over \partial \theta} \right]\]
- note that they update on the batch examples one at a time to maximize the regularizing effect
- Algorithms 6 & 7 in the paper are alternate ways of prioritizing the current example
-
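Putting the MER steps together, as I understand them. This is a sketch under assumptions, not the paper's Algorithm 1: the linear model, the toy stream, the placement of the current example last in each batch, and the names `sgd_example`, `interpolate`, `mer_step`, and `run_mer` are all mine.

```python
import random


def sgd_example(theta, x, y, lr):
    """One SGD step on a single example, loss 0.5 * (theta . x - y)^2."""
    residual = sum(t * xi for t, xi in zip(theta, x)) - y
    return [t - lr * residual * xi for t, xi in zip(theta, x)]


def interpolate(start, end, rate):
    return [a + rate * (b - a) for a, b in zip(start, end)]


def mer_step(theta, memory, current, s, k, alpha, beta, gamma, rng):
    theta_a0 = list(theta)                        # theta^A_0
    theta_a = list(theta)
    for _ in range(s):
        # Up to k - 1 stored examples plus the current one (order is an assumption).
        n_old = min(k - 1, len(memory))
        batch = rng.sample(memory, n_old) + [current]
        theta_i0 = list(theta_a)                  # theta^A_{i,0}
        theta_w = list(theta_a)
        for x, y in batch:                        # examples learned one at a time
            theta_w = sgd_example(theta_w, x, y, alpha)
        theta_a = interpolate(theta_i0, theta_w, beta)   # within-batch update
    return interpolate(theta_a0, theta_a, gamma)         # across-batch update


def run_mer(stream, capacity=20, s=3, k=4, alpha=0.1, beta=0.5, gamma=0.5, seed=0):
    rng = random.Random(seed)
    theta, memory = [0.0, 0.0], []
    for n_seen, example in enumerate(stream, start=1):
        theta = mer_step(theta, memory, example, s, k, alpha, beta, gamma, rng)
        # Reservoir update: the current example may or may not be kept.
        if len(memory) < capacity:
            memory.append(example)
        else:
            slot = rng.randrange(n_seen)
            if slot < capacity:
                memory[slot] = example
    return theta, memory


# Toy stream from y = 2 * x0 - x1 (hypothetical, for illustration only).
data_rng = random.Random(1)
stream = []
for _ in range(300):
    x = [data_rng.uniform(-1, 1), data_rng.uniform(-1, 1)]
    stream.append((x, 2.0 * x[0] - 1.0 * x[1]))

theta, memory = run_mer(stream)
print(theta, len(memory))
```

Each incoming example costs up to \(s \cdot k\) single-example gradient steps here, compared with one step for plain online SGD.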
Evaluation metrics:
- learning accuracy (LA): average accuracy for each task immediately after it has been learned
- retained accuracy (RA): final retained accuracy across all tasks learned sequentially
- backward transfer and interference (BTI): the average change in accuracy from when a task is learned to the end of training (positive good; large and negative is catastrophic forgetting)
-
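The three metrics can be read off an accuracy matrix. In this sketch `acc[i][j]` is assumed to be the accuracy on task `j` measured right after training on task `i`; the matrix values are made up, and averaging BTI over all tasks (including the last, whose change is zero) is my assumption.

```python
def continual_metrics(acc):
    """Return (LA, RA, BTI) from acc[i][j] = accuracy on task j after task i."""
    n_tasks = len(acc)
    learned = [acc[t][t] for t in range(n_tasks)]          # right after learning
    final = [acc[n_tasks - 1][t] for t in range(n_tasks)]  # at the end of training
    la = sum(learned) / n_tasks
    ra = sum(final) / n_tasks
    bti = sum(f - l for f, l in zip(final, learned)) / n_tasks
    return la, ra, bti


# Made-up accuracies for three tasks learned in sequence.
acc = [
    [0.9, 0.1, 0.1],
    [0.8, 0.9, 0.2],
    [0.7, 0.8, 0.9],
]
la, ra, bti = continual_metrics(acc)
print(la, ra, bti)
```

Under this averaging convention BTI is simply RA minus LA, so a negative BTI means accuracy was lost after the task was learned.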
Problems:
- in supervised learning: MNIST Permutations, where each task applies a fixed permutation to the MNIST pixels; MNIST Rotations, where each task contains digits rotated by a fixed angle between 0 and 180 degrees; Omniglot, where each task is one of 50 alphabets, 1623 characters overall
- in reinforcement learning: Catcher, where a board is moved left/right to catch an object that falls faster and faster; Flappy Bird, where the bird must fly between ever-tightening pipes
-
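For the MNIST Permutations benchmark, a task is just a fixed shuffle of pixel positions applied to every image. A small sketch, assuming images are flat lists of pixel values; the helper names and the tiny 4-pixel "image" are hypothetical.

```python
import random


def make_permutation(n_pixels, seed):
    """One fixed pixel permutation per task, determined by the task's seed."""
    rng = random.Random(seed)
    perm = list(range(n_pixels))
    rng.shuffle(perm)
    return perm


def apply_permutation(image, perm):
    """Reorder a flattened image; the same perm is used for the whole task."""
    return [image[p] for p in perm]


image = [10, 20, 30, 40]                   # stand-in for a flattened digit
task_perm = make_permutation(len(image), seed=3)
permuted = apply_permutation(image, task_perm)
print(task_perm, permuted)
```

The pixel values are unchanged, only their positions move, so each task is equally hard for a fully connected network but looks unrelated to the others.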
Compared against:
- online, same network trained straightforwardly one example at a time on the incoming non-stationary training data by simply applying SGD
- independent, one model per task with size of network reduced proportionally to keep total number of parameters fixed
- task input, trained as in online with a dedicated input layer per task
- EWC: Elastic Weight Consolidation (Kirkpatrick et al., 2017), like online but regularized to avoid catastrophic forgetting
- GEM: Gradient Episodic Memory (Lopez-Paz & Ranzato, 2017) uses episodic storage to modify the gradients of the latest example so that they do not interfere with past ones; stored examples are not used in ongoing training
-
Findings:
- MER seems to learn and retain the most over all tasks, faster, and with less memory
-
My reservations:
- MNIST again?
- Omniglot is not usually studied with any of the algorithms compared against: Lake (2015) achieves <5% error rate, and still <15% with a stripped-down version of their model and with 2 out of 3 of their baselines
- how much slower will the training be with single example batches and two meta-learning updates?