A Practical Way of Implementing Model-Agnostic Meta-Learning (MAML) in PyTorch


Table of Contents

  1. Introduction
  2. Model-Agnostic Meta-Learning
  3. Open-source Implementations
  4. Notes on Implementation Tricks
    1. Task-level Dataloader
    2. Base Learner
    3. Meta Learner
  5. References

Introduction

Recently I needed to apply meta-learning techniques in my own data-mining research. Model-Agnostic Meta-Learning (MAML) [1], the best-known meta-learning method, serves as an important and basic baseline, so I set out to learn some common practices and elegant ways of implementing MAML on my own. I first searched for popular open-source implementations on GitHub, and then took notes on the common tricks they use.

Model-Agnostic Meta-Learning

[13] is an excellent introduction both to meta-learning in general and to the MAML algorithm.

Open-source Implementations

There have been a number of MAML implementations on GitHub, which are listed as follows:

  1. [7] is the official code from the authors of the original paper, implemented in TensorFlow 1.0.
  2. [2] is the most popular PyTorch implementation; it is self-contained and elegant.
  3. [3] contains PyTorch implementations of a series of MAML-related methods, built on the higher library [5].
  4. [14] is also a self-contained PyTorch implementation, but a little out of date, as it only supports Python 2 and PyTorch 0.2.
  5. [9] is a PyTorch implementation built on the TorchMeta library [10].
  6. [4] contains fairly basic code that serves as part of a meta-learning tutorial.
  7. [8], which is in fact the CS330 course at Stanford University, has a homework whose starter code is also a good reference.
  8. [11] shows an example that applies MAML to a practical problem, namely the MeLU cold-start recommender system [12].
  9. [6] is a well-designed library for meta-learning.

Among the listed open-source implementations, I would recommend [2] for reading and learning: its implementation is elegant enough that I could quickly get the idea. My notes below generally follow this codebase.

Notes on Implementation Tricks

Below is one practical way of implementing MAML, consisting mainly of three parts: a task-level dataloader, a base learner, and a meta learner.

Task-level Dataloader

Different from vanilla machine-learning paradigms, a dataloader for meta-training should return a batch of tasks instead of a batch of samples at each iteration. For example, if the task is N-way K-shot image classification, the dataloader could return a tuple (x_spt, y_spt, x_qry, y_qry) holding the images and labels of the support set and the query set, respectively. A minimal sketch is given below.
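The following is a minimal sketch of such a task-level dataset, not taken from any of the codebases above; the class name FewShotTasks, the tensor layout [classes, examples per class, ...], and the random sampling scheme are all my own assumptions.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class FewShotTasks(Dataset):
    """Hypothetical task-level dataset: each item is one N-way K-shot *task*,
    not a single sample. Assumes `images` has shape [C, M, ...], i.e.,
    M examples for each of C classes."""

    def __init__(self, images, n_way=5, k_shot=1, k_query=15, num_tasks=10000):
        self.images = images
        self.n_way, self.k_shot, self.k_query = n_way, k_shot, k_query
        self.num_tasks = num_tasks

    def __len__(self):
        return self.num_tasks

    def __getitem__(self, idx):
        # Sample N classes, then K support and Q query examples per class.
        classes = torch.randperm(self.images.size(0))[: self.n_way]
        x_spt, y_spt, x_qry, y_qry = [], [], [], []
        for label, c in enumerate(classes):
            perm = torch.randperm(self.images.size(1))
            x_spt.append(self.images[c, perm[: self.k_shot]])
            x_qry.append(self.images[c, perm[self.k_shot : self.k_shot + self.k_query]])
            y_spt += [label] * self.k_shot
            y_qry += [label] * self.k_query
        return (torch.cat(x_spt), torch.tensor(y_spt),
                torch.cat(x_qry), torch.tensor(y_qry))

# A standard DataLoader then yields a *batch of tasks* per iteration:
# here x_spt has shape [task_batch, n_way * k_shot, 1, 28, 28].
loader = DataLoader(FewShotTasks(torch.randn(20, 40, 1, 28, 28)), batch_size=4)
```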

Base Learner

The base learner holds the network parameters (e.g., the weights of linear layers) much like an ordinary PyTorch module, but these parameters are in fact the meta-parameters. This leads to a difference in how the parameters are defined. Instead of using the predefined modules in torch.nn (such as nn.Linear()), the base learner defines its network structure directly with nn.Parameter() and implements the forward computation with the stateless functions in torch.nn.functional. The reason lies in how PyTorch modules track their parameters; a closer look can be taken in [15]. Implementing a complex network structure with manually defined parameters can be awkward, and in that case it is more convenient to use third-party libraries [6,5,10].

The forward function takes the data of one task as input, evaluates (predicts) on the data with a set of parameters, and outputs the loss or logits. The core design of the base learner is to realize the computation graphs of both the local update and the global update with a single forward function. Specifically, in addition to computing the forward pass with its own parameters, the base learner provides an alternative: computing with another set of parameters, given as an additional argument of the forward function that defaults to None. A sketch follows.
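Here is a minimal sketch of such a base learner, assuming a small MLP for N-way classification; the class name, layer sizes, and initialization are illustrative assumptions rather than the actual code of [2]:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BaseLearner(nn.Module):
    """Minimal sketch of a base learner: the weights are held as raw
    nn.Parameter tensors (the meta-parameters), and the forward pass is
    written with torch.nn.functional so it can run with *any* parameter set."""

    def __init__(self, in_dim=784, hidden=64, n_way=5):
        super().__init__()
        self.vars = nn.ParameterList([
            nn.Parameter(torch.randn(hidden, in_dim) * 0.01),  # w1
            nn.Parameter(torch.zeros(hidden)),                 # b1
            nn.Parameter(torch.randn(n_way, hidden) * 0.01),   # w2
            nn.Parameter(torch.zeros(n_way)),                  # b2
        ])

    def forward(self, x, params=None):
        # If no alternative parameter set is given, fall back to the
        # meta-parameters stored by the module itself.
        if params is None:
            params = self.vars
        w1, b1, w2, b2 = params
        x = F.relu(F.linear(x, w1, b1))
        return F.linear(x, w2, b2)  # logits
```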

When performing the local update, we maintain the local parameters as a separate copy of the meta-parameters and pass them to the forward function. The input data in this step is the support set of a task. That is, we build a computation graph that updates the local parameters via the meta-training loss on the support set.

When doing the global update, we input the query set of a task and evaluate it with the locally updated parameters. Since the local update was computed with create_graph=True, the resulting meta-testing loss is still differentiable with respect to the meta-parameters stored by the base learner. That is, we build a computation graph that updates the meta-parameters via the meta-testing loss on the query set. (Passing None, i.e., using the stored meta-parameters directly, corresponds to evaluating before any local adaptation has happened.)
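Putting the two updates together for one task, the sketch below (building on the hypothetical BaseLearner above; the function name, inner_lr, and inner_steps are assumed, not from [2]) shows how create_graph=True keeps the meta-testing loss differentiable with respect to the meta-parameters:

```python
import torch
import torch.nn.functional as F

def one_task_meta_loss(learner, x_spt, y_spt, x_qry, y_qry,
                       inner_lr=0.01, inner_steps=5):
    """Local update on the support set, then meta-testing loss on the query set."""
    # Local update: the fast weights start out as the meta-parameters themselves.
    fast_weights = list(learner.parameters())
    for _ in range(inner_steps):
        loss = F.cross_entropy(learner(x_spt, fast_weights), y_spt)
        # create_graph=True retains the inner-loop graph, so the query loss
        # below stays differentiable w.r.t. the original meta-parameters.
        grads = torch.autograd.grad(loss, fast_weights, create_graph=True)
        fast_weights = [w - inner_lr * g for w, g in zip(fast_weights, grads)]
    # Global update: meta-testing loss of the adapted parameters on the query set.
    return F.cross_entropy(learner(x_qry, fast_weights), y_qry)
```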

Meta Learner

The meta learner holds the base learner as a member variable. The forward function of the meta learner takes a batch of tasks as input, performs the local update for each task (by calling the forward function of the base learner), calculates the meta-testing losses (again via the base learner), and optimizes the meta-parameters, as sketched below.
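A corresponding sketch of the meta learner, again building on the hypothetical BaseLearner and one_task_meta_loss from the previous snippets:

```python
import torch
import torch.nn as nn

class MetaLearner(nn.Module):
    """Minimal sketch: holds the base learner and optimizes its meta-parameters."""

    def __init__(self, meta_lr=1e-3):
        super().__init__()
        self.learner = BaseLearner()
        self.meta_opt = torch.optim.Adam(self.learner.parameters(), lr=meta_lr)

    def forward(self, x_spt, y_spt, x_qry, y_qry):
        # Average the meta-testing losses over the batch of tasks, then take
        # one optimizer step on the meta-parameters.
        task_batch = x_spt.size(0)
        loss_q = sum(one_task_meta_loss(self.learner,
                                        x_spt[i], y_spt[i], x_qry[i], y_qry[i])
                     for i in range(task_batch)) / task_batch
        self.meta_opt.zero_grad()
        loss_q.backward()
        self.meta_opt.step()
        return loss_q.item()
```

A training loop then simply iterates over the task-level dataloader and feeds each batch of tasks to this forward function.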

References

[1] Finn, Chelsea, Pieter Abbeel, and Sergey Levine. “Model-agnostic meta-learning for fast adaptation of deep networks.” International conference on machine learning. PMLR, 2017.

[2] https://github.com/dragen1860/MAML-Pytorch

[3] https://github.com/cnguyen10/few_shot_meta_learning

[4] https://github.com/sudharsan13296/Hands-On-Meta-Learning-With-Python

[5] https://github.com/facebookresearch/higher

[6] https://github.com/learnables/learn2learn

[7] https://github.com/cbfinn/maml

[8] https://cs330.stanford.edu/

[9] https://github.com/tristandeleu/pytorch-maml

[10] https://github.com/tristandeleu/pytorch-meta

[11] https://github.com/hoyeoplee/MeLU

[12] Lee, Hoyeop, et al. “MeLU: Meta-learned user preference estimator for cold-start recommendation.” Proceedings of the 25th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining. 2019.

[13] Weng, Lilian. “Meta-Learning: Learning to Learn Fast.” https://lilianweng.github.io/posts/2018-11-30-meta-learning/, 2018.

[14] https://github.com/katerakelly/pytorch-maml

[15] https://github.com/pytorch/pytorch/issues/12659