<h1>Learning 3D physics simulators from video</h1>
<p>Where simulation is good enough, it is transforming robotics. But it’s nearly impossible to engineer simulators that can accurately model complex physical and visual properties.</p>
<p>Instead of <em>engineering</em> such a simulator, why not <em>train</em> one? In a new paper, we design an architecture that learns 3D physics simulators by training on multi-view video. It’s essentially a <a href="https://www.matthewtancik.com/nerf">NeRF</a> that learns a particle-based physics simulator in order to predict the future. We call it Visual Particle Dynamics, or VPD.</p>
<p>Because this model is very expressive, and can in principle be trained on real-world data, it has the potential to overcome the limitations of hand-engineered simulators. And because it uses NeRF for rendering, it can produce high-quality images.</p>
<figure class="image">
<img src="../assets/img/learned-simulators/media.gif" />
<figcaption>Simulating mixed-material physics. Left: ground truth, right: learned simulator.</figcaption>
</figure>
<p>This post describes why I’m excited about simulation, the ingredients that make VPD work, and then finally what the VPD model itself does.</p>
<p><a href="https://arxiv.org/abs/2312.05359">Paper link</a></p>
<p><a href="https://sites.google.com/view/latent-dynamics">Video web site</a></p>
<ul id="markdown-toc">
<li><a href="#sim-to-real-is-transforming-robotics" id="markdown-toc-sim-to-real-is-transforming-robotics">Sim to real is transforming robotics</a></li>
<li><a href="#background" id="markdown-toc-background">Background</a> <ul>
<li><a href="#graph-network-simulators" id="markdown-toc-graph-network-simulators">Graph network simulators</a></li>
<li><a href="#nerf" id="markdown-toc-nerf">NeRF</a></li>
<li><a href="#point-nerf" id="markdown-toc-point-nerf">Point-NeRF</a></li>
</ul>
</li>
<li><a href="#visual-particle-dynamics" id="markdown-toc-visual-particle-dynamics">Visual Particle Dynamics</a> <ul>
<li><a href="#overview" id="markdown-toc-overview">Overview</a></li>
<li><a href="#encoder" id="markdown-toc-encoder">Encoder</a></li>
<li><a href="#dynamics" id="markdown-toc-dynamics">Dynamics</a></li>
<li><a href="#rendering" id="markdown-toc-rendering">Rendering</a></li>
</ul>
</li>
<li><a href="#the-simulator-of-the-future" id="markdown-toc-the-simulator-of-the-future">The simulator of the future?</a></li>
<li><a href="#acknowledgements-and-citation" id="markdown-toc-acknowledgements-and-citation">Acknowledgements and citation</a></li>
</ul>
<h1 id="sim-to-real-is-transforming-robotics">Sim to real is transforming robotics</h1>
<p>Training robots in the real world is hard. Collecting demonstrations spends endless human hours on menial labor, and you can’t really set your robot loose to collect data on its own — it will inevitably find a way to break its own fingers and drop your favorite mug on the floor. Every experiment is an endless cycle of watching, debugging, resetting, repairing… On a bad day, I think even the real robot devotees harbor a secret wish that they could do the whole mess in sim.</p>
<p>And if you look around, you notice a striking pattern: making a robot do something where our simulations are “good enough” is becoming trivial, while making a robot do something hard to simulate remains basically impossible:</p>
<ul>
<li>Locomotion, which can be simulated with simplified physics and observations, is dominated by sim to real. Making a robot walk on varied terrain takes a few hours of training.</li>
<li>Manipulation in sim requires modeling complex physics like deformation and surface properties and lighting, so there is little adoption of sim to real. State of the art manipulation involves <a href="https://arxiv.org/abs/2307.15818">collecting 130,000 human demonstrations over 17 months</a> and produces a system that can pick up new objects around 3/4 of the time.</li>
</ul>
<figure class="youtube">
<iframe width="560" height="315" src="https://www.youtube.com/embed/cqvAgcQl6s4?si=epg5BfMmUtVI7yn9" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share" allowfullscreen="">
</iframe>
<figcaption>Locomotion with sim to real has gotten incredible. Credit: <a href="https://extreme-parkour.github.io/">https://extreme-parkour.github.io/</a></figcaption>
</figure>
<p>So to summarize: sim to real works incredibly well and everybody wishes they could do it, but sim today isn’t good enough to solve most tasks. Maybe the obvious conclusion is that we should all be working on simulation?</p>
<p>Now I’m not suggesting that you go out and start trying to contribute C code to MuJoCo (though that is commendable!). After all, I’m a deep learner at heart. And for us, every problem is a machine learning problem. So what does the deep learning version of building a simulator look like?</p>
<p>Here are some things we might want:</p>
<ol>
<li>You can train it on real data. After all, the goal is to build a simulator that is good enough to use for sim to real. If you can’t train on real data, what’s the point?</li>
<li>It should be more general than a normal simulator. Lots of materials in the real world are at least a little bit deformable, friction is more complicated than $n$ coefficients, there are subscale textures that affect contacts, and no simulator can model <em>everything</em>.</li>
<li>It keeps the “nice” properties of normal simulators. You want to have a 3D representation so you can compose a new scene by adding objects, render a camera moving through it, etc.</li>
</ol>
<p>What would this look like, if we succeeded? We would have <em>some kind</em> of model that can construct a 3D representation from observations, that can handle all the complexities of real dynamics, and that we can use to compose and edit new scenes to simulate.</p>
<p>We’re definitely not there yet.</p>
<p>But our new work is a first step in this direction, and I think it could be the foundation of a new family of simulators.</p>
<h1 id="background">Background</h1>
<p>Before I go into the details of the model, I’ll give some background on the tools that make this possible.</p>
<h2 id="graph-network-simulators">Graph network simulators</h2>
<p>The first bit of background that will come in handy is on graph network learned simulators, or GNS. For the last few years, a group at DeepMind<sup id="fnref:other_gns" role="doc-noteref"><a href="#fn:other_gns" class="footnote" rel="footnote">1</a></sup> has been working with these models across a ton of different physical systems, from <a href="https://arxiv.org/abs/2010.03409">fabric</a> to <a href="https://arxiv.org/abs/2002.09405">granular materials</a> to <a href="https://arxiv.org/abs/2212.03574">rigids</a>.</p>
<div class="multi-figure">
<figure class="subfigure">
<img src="/assets/img/learned-simulators/FIGNet.gif" alt="" style="height: 200px;" />
</figure>
<figure class="subfigure">
<img src="/assets/img/learned-simulators/gns_fabric.gif" alt="" style="height: 200px;" />
</figure>
</div>
<div class="multi-figure">
<figure class="subfigure">
<img src="/assets/img/learned-simulators/gns_water.gif" alt="" style="height: 200px;" />
</figure>
<figure class="subfigure">
<img src="/assets/img/learned-simulators/gns_gel.gif" alt="" style="height: 200px;" />
</figure>
<figure class="subfigure">
<img src="/assets/img/learned-simulators/gns_sand.gif" alt="" style="height: 200px;" />
</figure>
</div>
<p>These models use graph neural networks as dynamics models. They represent a scene as a set of points, which may correspond to particles of sand or fluid, or to vertices in a mesh. When encoded as a graph network input, each of those points becomes a node in the graph, and points which have some relationship to each other, like close 3D proximity or a shared mesh edge, get edges connecting them. After doing many steps of message passing among these nodes, the GNS predicts a $\Delta x$ for every node.</p>
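<p>To make the mechanics concrete, here is a minimal sketch of one message-passing step in this style. This is my own simplified pseudocode rather than the exact architecture from those papers; <code class="language-plaintext highlighter-rouge">edge_mlp</code> and <code class="language-plaintext highlighter-rouge">node_mlp</code> stand in for small learned MLPs.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import jax.numpy as jnp

def message_passing_step(node_feats, positions, senders, receivers,
                         edge_mlp, node_mlp):
    # Edges only ever see *relative* positions, which is what makes
    # the model spatially equivariant.
    rel_pos = positions[receivers] - positions[senders]
    edge_inputs = jnp.concatenate(
        [node_feats[senders], node_feats[receivers], rel_pos], axis=-1)
    messages = edge_mlp(edge_inputs)

    # Sum incoming messages at each node, then update its features.
    agg = jnp.zeros((node_feats.shape[0], messages.shape[-1]))
    agg = agg.at[receivers].add(messages)
    return node_mlp(jnp.concatenate([node_feats, agg], axis=-1))

# After many such steps, a small decoder MLP reads a delta-x off each node.
</code></pre></div></div>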
<figure class="image">
<img src="/assets/img/learned-simulators/Untitled.png" />
<figcaption>A GNS operating on rigid meshes.</figcaption>
</figure>
<p>The key advantage of these models is that they are <em>local</em> and <em>spatially equivariant</em>. Because of the way information propagates through a GNS, a change in the state at one end of the workspace doesn’t affect predictions at the opposite end. And because the GNS only observes relative positions, and never global ones, it makes the same predictions if you take all the points and move them somewhere else.</p>
<p>When you put these two properties together, you get something really amazing: each single timestep of training data acts like many separate training points. Consider a homogeneous material like sand: each particle of sand follows the same physical laws as every other particle of sand. So if you have one trajectory with a thousand sand particles, the GNS trains <em>almost</em> as if you have a thousand trajectories in your training set. In reality the GNS softly observes a region of particles instead of just one, so maybe it’s like a hundred trajectories, or maybe just ten. But they exhibit impressive data efficiency.</p>
<p>And by the same token, a GNS can generalize compositionally. After all, the model only defines the local behavior of particles. It never even knows whether there are a thousand sand particles or a million, so it doesn’t break if you train on a thousand and test on a million.</p>
<p>So GNS models are pretty incredible. So far, though, they’ve always relied on having ground-truth states: particles which persist and move over time, completely specifying the state of the system. That’s more or less impossible to get in the real world. You can approximate it by having CAD models of everything you want to observe and sticking QR codes all over, but even that only works for rigids.</p>
<h2 id="nerf">NeRF</h2>
<p>The other tool we need is <a href="https://arxiv.org/abs/2003.08934">Neural Radiance Fields</a>, more commonly known as NeRF.</p>
<figure class="video">
<video controls="" width="560">
<source src="/assets/img/learned-simulators/teaser.mp4" type="video/mp4" />
</video>
<figcaption>Credit: Zip-NeRF, <a href="https://jonbarron.info/zipnerf/">https://jonbarron.info/zipnerf/</a></figcaption>
</figure>
<p>You might already be familiar with NeRF, but in case you’re not: NeRF trains a neural network to predict the <em>color</em> and <em>density</em> of a location in space, and renders a <em>ray</em> by querying this network at many XYZ points along the ray.</p>
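<p>As a rough sketch, rendering one ray boils down to a simple quadrature; here <code class="language-plaintext highlighter-rouge">field_fn</code> is an assumed stand-in for the trained network, returning a color and a density per sample point:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import jax.numpy as jnp

def render_ray(field_fn, origin, direction, ts):
    # Sample the field at many points along the ray.
    points = origin + ts[:, None] * direction   # (num_samples, 3)
    rgb, sigma = field_fn(points)               # colors (N, 3), densities (N,)

    # Volume-rendering quadrature: a sample's weight is the probability
    # the ray survives to reach it times the probability it stops there.
    deltas = jnp.diff(ts, append=ts[-1] + 1e10)
    alpha = 1.0 - jnp.exp(-sigma * deltas)
    transmittance = jnp.exp(-jnp.cumsum(
        jnp.concatenate([jnp.zeros(1), (sigma * deltas)[:-1]])))
    weights = transmittance * alpha
    return jnp.sum(weights[:, None] * rgb, axis=0)  # final pixel color
</code></pre></div></div>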
<figure class="image">
<img src="/assets/img/learned-simulators/Untitled%201.png" />
<figcaption>Credit: NeRF, <a href="https://arxiv.org/abs/2003.08934">https://arxiv.org/abs/2003.08934</a></figcaption>
</figure>
<p>In this way, a trained NeRF encodes the 3D appearance of a single scene. It generalizes over camera positions by construction, but it has no ability to generalize over new scenes. To get a NeRF for a version of the scene where some objects have moved and the weather is cloudy you have to train a new NeRF from scratch.</p>
<p>NeRF’s results are incredible, but there is one specific thing that NeRF does that I think is still deeply under-appreciated years later: NeRF gives us a way to turn 2D supervision into 3D supervision. Multi-view consistency turns out to be the Rosetta stone of 3D machine learning. People (including me! my first paper was about this!) have been wanting to learn things about the 3D world forever, and these efforts have always been frustrated by the fact that we only ever have 2D images. No longer.</p>
<h2 id="point-nerf">Point-NeRF</h2>
<p>An interesting extension to NeRF is <a href="https://xharlie.github.io/projects/project_sites/pointnerf/">Point-NeRF</a>. Instead of having the weights of the model represent the state of the world, Point-NeRF has an explicit representation: points in space tagged with latent vectors. When rendering some location in space $x$, the Point-NeRF MLP doesn’t directly observe $x$ like a NeRF would. Instead, the Point-NeRF gets information derived from points in the local neighborhood, and those points must contain all the features needed for the Point-NeRF to render.</p>
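<p>In spirit, a field query looks something like the sketch below: find the nearest particles, aggregate their latent features (here with simple inverse-distance weighting), and decode. The real Point-NeRF aggregation has a few more pieces, and <code class="language-plaintext highlighter-rouge">decoder_mlp</code> is an assumed placeholder.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import jax.numpy as jnp

def query_point_field(x, particle_xyz, particle_feats, decoder_mlp, k=8):
    # Find the k particles nearest to the query location x.
    dists = jnp.linalg.norm(particle_xyz - x, axis=-1)
    idx = jnp.argsort(dists)[:k]

    # Inverse-distance weighting: closer particles contribute more.
    w = 1.0 / (dists[idx] + 1e-6)
    w = w / jnp.sum(w)
    local_feats = jnp.sum(w[:, None] * particle_feats[idx], axis=0)

    # The decoder never sees the global coordinate x, only features
    # aggregated from its local neighborhood.
    return decoder_mlp(local_feats)   # e.g. a color and a density
</code></pre></div></div>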
<div class="multi-figure">
<figure class="subfigure">
<img src="/assets/img/learned-simulators/Untitled%202.png" alt="" style="height: px;" />
</figure>
<figure class="subfigure">
<img src="/assets/img/learned-simulators/Untitled%203.png" alt="" style="height: px;" />
</figure>
</div>
<p>With this explicit scene representation, Point-NeRF unlocks the ability to generalize over scenes. Instead of training a new Point-NeRF for each scene, you can just point its image encoder at a few views of the scene and get your point cloud instantly. This extra structure also lets you get away with only a few views per scene instead of the very dense coverage that NeRF needs.</p>
<figure class="image">
<img src="/assets/img/learned-simulators/Untitled%204.png" />
<figcaption>Credit: Point-NeRF, <a href="https://xharlie.github.io/projects/project_sites/pointnerf/">https://xharlie.github.io/projects/project_sites/pointnerf/</a></figcaption>
</figure>
<h1 id="visual-particle-dynamics">Visual Particle Dynamics</h1>
<p>The fundamental limitation of graph network simulators has always been that they require a form of supervision which is totally impractical to collect: ground-truth 3D locations of every point representing a scene at every moment in time. The core insight of VPD is that NeRF’s ability to supervise 3D information using 2D observations exactly addresses this weakness. If we had a GNS that could perceive a scene and render its predictions, we could supervise it with easy-to-collect video.</p>
<figure class="image">
<img src="/assets/img/learned-simulators/squish_dual_rotating_camera-2.gif" />
<figcaption>VPD simulates this poor smooshy ducky in 3D, and can render it from any angle.</figcaption>
</figure>
<p>VPD does exactly this. It applies the simulation architecture of a GNS to the particle representation of a Point-NeRF, combining perception, dynamics, and rendering in one pipeline that can be trained end to end.</p>
<h2 id="overview">Overview</h2>
<p><img src="../assets/img/learned-simulators/latent_physics_v1_3.png" alt="Visual Particle Dynamics" /></p>
<p>Visual Particle Dynamics, or VPD, has three key components:</p>
<ul>
<li>First is an encoder that takes one or more posed images describing a scene and encodes them into a set of particles in 3D space, each with a latent feature vector.</li>
<li>Next is a dynamics model that transforms the observed particles at step $t$ to predicted particles at step $t+1$.</li>
<li>Finally we have a renderer, which uses ray-based rendering to decode the set of points into an image from a queried camera pose.</li>
</ul>
<p>Putting it all together, VPD is a video prediction model: given a couple of frames, it can predict the rest of the video. But it’s also more than that. Because it has a 3D representation of the world and a 3D renderer, it supports 3D editing of the underlying scene during simulation and arbitrary camera motion. Most importantly of all, it can be trained end-to-end on data you could actually collect and still keep all these nice 3D tricks.</p>
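<p>A high-level sketch of one training step, using my own placeholder names (<code class="language-plaintext highlighter-rouge">encode</code>, <code class="language-plaintext highlighter-rouge">dynamics</code>, <code class="language-plaintext highlighter-rouge">render</code>, <code class="language-plaintext highlighter-rouge">pixel_loss</code>) rather than the paper’s exact interfaces:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>def vpd_train_step(encode, dynamics, render, pixel_loss, batch, num_steps):
    # 1. Encode posed RGB-D frames into latent 3D particles.
    particles = encode(batch.images, batch.cameras)

    # 2. Roll the particles forward with the graph-network dynamics.
    for _ in range(num_steps):
        particles = dynamics(particles)

    # 3. Render the prediction from query camera poses and compare it to
    #    the real future frames. This photometric error is the only
    #    supervision the whole pipeline needs.
    predicted = render(particles, batch.future_cameras)
    return pixel_loss(predicted, batch.future_images)
</code></pre></div></div>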
<h2 id="encoder">Encoder</h2>
<figure class="image">
<img src="/assets/img/learned-simulators/Untitled%205.png" />
<figcaption>The VPD encoder applies a convnet to produce a latent vector for every pixel, then uses depth data and camera geometry to turn those latent vectors into 3D particles.</figcaption>
</figure>
<p>The first component of VPD is the encoder, which turns every pixel in the input images into a particle in 3D space, each one tagged with a latent vector. For Point-NeRF, the latent vectors carry information about the visual appearance of that point, but since VPD is trained end to end its latent vectors will also capture something about geometry and physical properties.</p>
<p>These observations (and particles) can come from however many cameras we have. For now we assume that our data is in the form of multi-view RGB-D videos, such as would be recorded by multiple depth cameras, but in the paper we find that using a monocular depth estimation model works almost as well.</p>
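<p>The geometric half of the encoder is standard pinhole back-projection. Here is a sketch, assuming an inverse intrinsics matrix <code class="language-plaintext highlighter-rouge">K_inv</code> and a camera-to-world transform; the per-pixel convnet features are computed separately:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import jax.numpy as jnp

def unproject_to_particles(depth, feats, K_inv, cam_to_world):
    # depth: (H, W); feats: (H, W, C) per-pixel latents from the convnet.
    H, W = depth.shape
    u, v = jnp.meshgrid(jnp.arange(W), jnp.arange(H))
    pixels = jnp.stack([u, v, jnp.ones_like(u)], axis=-1).reshape(-1, 3)

    # Back-project each pixel along its camera ray to its measured depth,
    # then transform from camera coordinates into the world frame.
    rays = pixels @ K_inv.T
    cam_xyz = rays * depth.reshape(-1, 1)
    world_xyz = cam_xyz @ cam_to_world[:3, :3].T + cam_to_world[:3, 3]
    return world_xyz, feats.reshape(-1, feats.shape[-1])
</code></pre></div></div>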
<h2 id="dynamics">Dynamics</h2>
<p>Next we have a dynamics model in the form of a graph network. This GNN takes the particles from the encoder and predicts how they will change. It moves them in 3D space, representing motion, and changes the features of the latent vector, representing changes in appearance such as shadows.</p>
<figure class="image">
<img src="/assets/img/learned-simulators/Untitled%206.png" />
<figcaption>A hierarchical graph network. Past and present observations (blue and green particles) send information to abstract nodes. The abstract nodes communicate amongst themselves, then predict the motion of the present particles.</figcaption>
</figure>
<p>Previous GNS architectures have used the velocity of each point as an input to the model. They assumed access to ground-truth 3D information, including correspondences between points at different times, but camera data has no such thing; there is no easy way to “track” the motion of every pixel in a video, especially not with occlusion. Instead, VPD simply gives all the particles from multiple timesteps to a giant graph network and asks it to figure it out.</p>
<figure class="image">
<img src="/assets/img/learned-simulators/Untitled%207.png" width="400px" />
<figcaption>The graphs are big enough to be… challenging.</figcaption>
</figure>
<p>Multiple timesteps’ worth of pixels from multiple cameras rapidly becomes a tremendous number of points. To deal with these huge graphs more efficiently, VPD uses a hierarchical graph network, where most of the message passing work happens on a small set of more abstract particles. All the messy details of this are in <a href="https://arxiv.org/abs/2312.05359">the paper</a>, but really this is just a computational optimization. This GNN acts much like those in previous GNS models, with the dynamics of one side of the scene largely independent of the other.</p>
<h2 id="rendering">Rendering</h2>
<p>Finally, after the dynamics model predicts how the particles move and change over time, the renderer can take that prediction and turn it into an image. Just like Point-NeRF, the VPD renderer predicts color and density at a location in space given features from the particles near that location. That means that in order to render some $(x, y, z)$ correctly, the right particles need to be nearby.</p>
<figure class="image">
<img src="/assets/img/learned-simulators/Untitled.gif" />
<figcaption>Poor ducky can never catch a break.</figcaption>
</figure>
<p>This is the engine that drives training for the whole model: if our ducky ends up with its face at $(2, -1, 1)$, and the particles representing that face are predicted at $(2, -1, 1.5)$, the dynamics model will get a gradient to move those particles down to the right place so that the renderer can put the right colors there.</p>
<h1 id="the-simulator-of-the-future">The simulator of the future?</h1>
<p>As all the gifs in this post show, VPD has some capabilities that normal simulators would struggle with, like simulating deformables and making renders match ground truth. But what I’m most excited about is what it <em>shares</em> with other simulators: an interpretable, editable 3D representation of a scene. Unlike a 2D video model, you can “clone” a real scene by passing it through the VPD encoder, and then interact with its particle representation to make something new. Here’s an example of removing certain objects from the scene:</p>
<div class="multi-figure">
<figure class="subfigure">
<img src="/assets/img/learned-simulators/media-2.gif" alt="" style="height: px;" />
<figcaption>Original</figcaption>
</figure>
<figure class="subfigure">
<img src="/assets/img/learned-simulators/media-3.gif" alt="" style="height: px;" />
<figcaption>Removed cylinder</figcaption>
</figure>
<figure class="subfigure">
<img src="/assets/img/learned-simulators/media-4.gif" alt="" style="height: px;" />
<figcaption>Removed floor</figcaption>
</figure>
</div>
<p>This is a simple edit, but you could imagine creating a library of objects (with a few images and one forward pass each) and re-combining them at will. The particle representation is easy to interpret and visualize, making editing much simpler.</p>
<p>Since VPD models dynamics and appearance as local properties, it supports the composition and simulation of much larger scenes at test time. Just grab all the particles you want from various examples and stick them together.</p>
<figure class="image">
<img src="/assets/img/learned-simulators/media-5.gif" width="400px" />
<figcaption>So many squishies!</figcaption>
</figure>
<p>This is just the beginning, and there are so many research directions on top of VPD. There are various improvements that I see as low-hanging fruit:</p>
<ul>
<li>Currently the model is fully deterministic, which leads to blurry predictions when there is uncertainty or chaotic dynamics. I would love to see a fully generative version of VPD, which could challenge 2D video models in terms of image quality while maintaining VPD’s highly accurate dynamics.</li>
<li>VPD knows a lot about dynamics, but nothing about semantics. It would be interesting to run something like <a href="https://segment-anything.com/">Segment Anything</a> on the images to add object identity as a feature to every particle. Since VPD particles track objects throughout their rollouts, this would straightforwardly enable predictions of object-level motion and behavior.</li>
<li>The hierarchical GNS architecture is very memory-intensive, and efficiency improvements would lead to big quality improvements. We currently use $2^{14}$ points and train on 6-step rollouts, and we know that increasing those numbers produces big gains in fidelity, not to mention making it possible to train on very large scenes.</li>
</ul>
<p>Perhaps in the future robotics will rely on simulators like these. Imagine cloning a new task into simulation just by pointing a camera at it, then fine-tuning a policy to solve it via sim to real. Or a simulator that matches the real world more closely with every experiment you run on hardware, continually adapting to new objects and becoming more capable. Or the “foundation model” of learned simulators, trained on all the world’s data and in turn able to safely train robots to interact with humans, cars, and the world outside the lab.</p>
<p>That’s a future of robotics that I’d like to see.</p>
<h1 id="acknowledgements-and-citation">Acknowledgements and citation</h1>
<p>This work would not have been possible without my amazing co-authors, Tatiana Lopez-Guevara, Tobias Pfaff, Yulia Rubanova, Thomas Kipf, Kim Stachenfeld, and Kelsey Allen. This post only represents my own views, though, so don’t blame them if I’ve said something outrageous.</p>
<p>Here’s the bibtex for our paper, in case it becomes relevant in your own work:</p>
<div class="language-latex highlighter-rouge"><div class="highlight"><pre class="highlight"><code>@inproceedings<span class="p">{</span>
whitney2024learning,
title=<span class="p">{</span>Learning 3D Particle-based Simulators from <span class="p">{</span>RGB<span class="p">}</span>-D Videos<span class="p">}</span>,
author=<span class="p">{</span>William F. Whitney and Tatiana Lopez-Guevara and Tobias Pfaff and Yulia Rubanova and Thomas Kipf and Kimberly Stachenfeld and Kelsey R. Allen<span class="p">}</span>,
booktitle=<span class="p">{</span>The Twelfth International Conference on Learning Representations<span class="p">}</span>,
year=<span class="p">{</span>2024<span class="p">}</span>,
url=<span class="p">{</span>https://openreview.net/forum?id=4rBEgZCubP<span class="p">}</span>
<span class="p">}</span>
</code></pre></div></div>
<div class="footnotes" role="doc-endnotes">
<ol>
<li id="fn:other_gns" role="doc-endnote">
<p>And others in parallel, e.g. <a href="http://dpi.csail.mit.edu">DPI-Net</a> from a group at MIT.</p>
</li>
</ol>
</div>

<h1>Parallelizing neural networks on one GPU with JAX</h1>
<!-- Code for this post was written in https://github.com/willwhitney/jax-parallel -->
<p>Most neural network libraries these days give amazing computational performance for training <em>large</em> neural networks.
But small networks, which aren’t big enough to usefully “fill” a GPU, leave a lot of available compute unused.
Running a small network on a GPU is a bit like buying an apartment building and then living in the janitor’s closet.</p>
<p>In this article, I describe how to get your money’s worth by training dozens of networks at once.
As you follow along, we’ll efficiently train dozens of small neural networks in parallel on a single GPU using the <a href="https://jax.readthedocs.io/en/latest/notebooks/quickstart.html#Auto-vectorization-with-vmap"><code class="language-plaintext highlighter-rouge">vmap</code></a> function from <a href="https://github.com/google/jax">JAX</a>.
Whether you are training ensembles, sweeping over hyperparameters, or averaging across random seeds, this technique can give you a 10x-100x improvement in computation time.
If you haven’t tried JAX yet, this may give you a reason to.</p>
<p>All of this was originally implemented as part of my library for evaluating representations, <a href="https://github.com/willwhitney/reprieve">reprieve</a>.
If you’re interested in learning about the pitfalls of representation learning research and how to avoid them, I wrote <a href="/representation-quality-and-the-complexity-of-learning.html">a blog post</a> on that too.</p>
<p>If you’re just here for the code, there’s a <a href="https://colab.research.google.com/drive/1-hVEZ8jck2nzIqmRgSmjQvxJp1wO2HI5?usp=sharing">colab</a> that has what you need.</p>
<figure class="image">
<img src="/assets/img/bootstrap_compare.png" />
<figcaption>Comparison of a single network versus a bootstrapped ensemble. With parallel training, ensembles of small networks are just as quick to train as a single net.</figcaption>
</figure>
<ul id="markdown-toc">
<li><a href="#the-difficulty-of-accelerating-small-networks" id="markdown-toc-the-difficulty-of-accelerating-small-networks">The difficulty of accelerating small networks</a></li>
<li><a href="#large-batches-fill-gpus-but-learn-worse" id="markdown-toc-large-batches-fill-gpus-but-learn-worse">Large batches fill GPUs but learn worse</a></li>
<li><a href="#training-more-networks-in-parallel" id="markdown-toc-training-more-networks-in-parallel">Training more networks in parallel</a> <ul>
<li><a href="#automatic-batching-with-jax-and-vmap" id="markdown-toc-automatic-batching-with-jax-and-vmap">Automatic batching with JAX and <code class="language-plaintext highlighter-rouge">vmap</code></a></li>
<li><a href="#a-first-draft-of-parallel-network-training-with-vmap" id="markdown-toc-a-first-draft-of-parallel-network-training-with-vmap">A first draft of parallel network training with <code class="language-plaintext highlighter-rouge">vmap</code></a></li>
</ul>
</li>
<li><a href="#bootstrapped-ensembles" id="markdown-toc-bootstrapped-ensembles">Bootstrapped ensembles</a> <ul>
<li><a href="#a-bootstrapped-data-sampler" id="markdown-toc-a-bootstrapped-data-sampler">A bootstrapped data sampler</a></li>
<li><a href="#training-the-bootstrapped-ensemble" id="markdown-toc-training-the-bootstrapped-ensemble">Training the bootstrapped ensemble</a></li>
</ul>
</li>
<li><a href="#conclusion" id="markdown-toc-conclusion">Conclusion</a></li>
</ul>
<h2 id="the-difficulty-of-accelerating-small-networks">The difficulty of accelerating small networks</h2>
<p>With the end of Moore’s Law-style clock speed scaling, modern high-performance computing platforms get good performance not by taking less time for a single operation, but by doing more in parallel.
They are <em>wider</em>, not <em>faster</em>.
This applies to accelerators like GPUs and TPUs, or even Apple’s new <a href="https://debugger.medium.com/why-is-apples-m1-chip-so-fast-3262b158cba2">laptop SoCs</a>.</p>
<p>The operations used in neural network training are pretty ideal for taking advantage of very wide architectures.
Large matrix multiplies consist of huge numbers of smaller operations that can be executed at the same time.
On top of that, we always use minibatch training, where we compute a loss gradient on tens, hundreds, or thousands of examples in parallel, then average those gradients to estimate the “true” gradient on the dataset.</p>
<p>Modern automatic differentiation libraries like <a href="https://pytorch.org">PyTorch</a> are optimized for squeezing as much performance as possible out of a wide accelerator for these kinds of workloads.
Train a ResNet-50 on your GPU and you’ll likely see GPU utilization numbers up near 100%, indicating that PyTorch is squeezing all the speed possible out of your hardware.</p>
<p>However, for training <em>small</em> networks, we run into fundamental limits of parallelization.
To be sure, a two-layer MLP will run much faster than the ResNet-50.
But the ResNet has about 4B multiply-accumulate operations, while the MLP has only 100K.<sup id="fnref:flop_counter" role="doc-noteref"><a href="#fn:flop_counter" class="footnote" rel="footnote">1</a></sup>
Much as we might like it to, our MLP will not train 40,000 times faster than a ResNet, and if we inspect our GPU utilization we can see why.
Unlike the ResNet, which uses ~100% of the GPU, the MLP may only use 2-3%.</p>
<p>A simple explanation is that their computation graphs aren’t as wide as the GPU is.
Glossing over a ton of complexity about data loading, fixed costs per loop, and how GPUs actually work, a small network with a reasonable batch size just doesn’t have enough parallelizable operations to use the entire GPU efficiently.</p>
<h2 id="large-batches-fill-gpus-but-learn-worse">Large batches fill GPUs but learn worse</h2>
<p>One way to use more compute in parallel would be to increase the batch size.
Instead of using batches of, say, 128 elements, we could crank that up until we fill the GPU.
In fact, why not use the entire dataset as one batch and parallelize across every element!</p>
<p>On MNIST, we can actually try this.
With a batch size of 128, we see GPU utilization at ~2% and a speed of about 11s / epoch.
By caching the entire dataset in GPU memory and performing full-batch gradient descent (i.e. using the whole dataset as one batch), we can get up to a frankly disturbing 0.01s / epoch with 97% GPU utilization!</p>
<p>Unfortunately, this doesn’t correspond to faster learning; the resulting model only has 79% test accuracy, compared to 98% for our small-batch model.
I was lazy and didn’t bother to adjust any hyperparameters, and assuredly we could squeeze out a bit more performance with careful tuning.
However, as a rule we don’t expect very large batch gradient descent to yield performance that’s as good as small batch.<sup id="fnref:keskar" role="doc-noteref"><a href="#fn:keskar" class="footnote" rel="footnote">2</a></sup></p>
<p>Even ignoring issues of generalization error, large batches aren’t very computationally efficient for most problems.
A theoretical argument against very large batches comes from a classic paper by Bottou & Bousquet.<sup id="fnref:bottou" role="doc-noteref"><a href="#fn:bottou" class="footnote" rel="footnote">3</a></sup>
We can think of a full-batch gradient descent update as being very accurate, but computationally expensive, and a small-batch update as being highly approximate, but very cheap.
Bottou & Bousquet show that taking lots of approximate updates results in much faster learning than taking fewer accurate updates.</p>
<p>While using very large batches can definitely saturate our GPU, they don’t actually help us train a small network any faster.
So what can we do instead?</p>
<h2 id="training-more-networks-in-parallel">Training more networks in parallel</h2>
<p>We can’t train a small network much faster on our current hardware, at least not without any exotic tricks.
If we’re careful to make sure our data loading is quick, we can probably train our two-layer MLP 400x faster than a ResNet-50.
And that’s pretty fast!
But we’re still leaving an additional 100x improvement on the table, at least according to our ballpark estimate that the ResNet uses 40,000x as much compute.</p>
<p>What should we do with the whole rest of our GPU?
Well, in practice we often don’t just want to train <em>one</em> neural network.
We might want to run with many random seeds to be confident in our results, or we could sweep over different hyperparameter settings, or we could even train a <a href="https://en.wikipedia.org/wiki/Bootstrapping_%28statistics%29">bootstrapped</a> ensemble of networks for higher accuracy.
Instead of waiting for our first network to finish training before starting the next, we could run all of these experiments at the same time!</p>
<p>The simplest version of this is to run our training script multiple times, putting several copies of (almost) the same job on the same GPU.
This works, and has the advantage of flexibility: those jobs don’t have to have anything in common, they just have to stay within the memory budget of the GPU.
It treats your GPU like a tiny cluster, agnostic to the content of the job scripts.
However, this deployment strategy has a lot of overhead: you get multiple Python processes, multiple copies of the library on the GPU, multiple data transfer calls between CPU and GPU… the list goes on.
In practice you can run a few small jobs on one GPU, but you’ll run out of GPU memory and clog your GPU before you get to 100.</p>
<p>Instead we’re going to see how to avoid duplicating work for our computer by writing an ordinary training step function, then using JAX to batch the computation over many neural networks at once.</p>
<h3 id="automatic-batching-with-jax-and-vmap">Automatic batching with JAX and <code class="language-plaintext highlighter-rouge">vmap</code></h3>
<p>A lot has been written about JAX in the past, so I’ll give only a cursory introduction.
<a href="https://github.com/google/jax">JAX</a> is an exciting new library for fast differentiable computation with support for accelerators like GPUs and TPUs.
It is not a neural network library; in a nutshell, it’s a library that you could build a neural network library on top of.
At the core of JAX are a few functions which take in functions as arguments and return <em>new</em> functions, which are transformed versions of the old ones.
It also includes an accelerator-backed version of <a href="https://numpy.org">numpy</a>, packaged as <code class="language-plaintext highlighter-rouge">jax.numpy</code>, which I will refer to as <code class="language-plaintext highlighter-rouge">jnp</code>.
For more I’ll refer you to <a href="https://blog.evjang.com/2019/02/maml-jax.html">Eric Jang’s nice introduction</a>, which covers meta-learning in JAX, and the <a href="https://jax.readthedocs.io/en/latest/notebooks/quickstart.html">JAX quickstart guide</a>.</p>
<p>On to a quick summary of the core functions in JAX that we’ll use, along with the arguments we care about:</p>
<p><a href="https://jax.readthedocs.io/en/latest/jax.html#jax.jit"><code class="language-plaintext highlighter-rouge">jax.jit(fun, ...)</code></a>: Takes a function <code class="language-plaintext highlighter-rouge">fun</code> and returns a faster version. There’s more to it than this, but enough for our purposes.</p>
<p><a href="https://jax.readthedocs.io/en/latest/jax.html#jax.grad"><code class="language-plaintext highlighter-rouge">jax.grad(fun, ...)</code></a>: Takes a function <code class="language-plaintext highlighter-rouge">fun</code> and returns a function that computes its <em>gradient</em> with respect to (by default) the first argument. For example, we could define <code class="language-plaintext highlighter-rouge">g = jax.grad(lambda x: x**2)</code>, and then use it by calling <code class="language-plaintext highlighter-rouge">g(2)</code>.
In essence, what we did was:</p>
<ol>
<li>Define $f(x) = x^2$ using a lambda expression.</li>
<li>Make a new function $g = \frac{df}{dx}(x) = 2x$.</li>
<li>Evaluate $g(2.0) \rightarrow 4.0$.</li>
</ol>
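<p>Putting that example into runnable form:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import jax

g = jax.grad(lambda x: x ** 2)
print(g(2.0))        # -> 4.0

# Transformations compose: jit(grad(f)) is a compiled derivative.
fast_g = jax.jit(g)
print(fast_g(2.0))   # -> 4.0
</code></pre></div></div>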
<p><a href="https://jax.readthedocs.io/en/latest/jax.html#jax.vmap"><code class="language-plaintext highlighter-rouge">jax.vmap(fun, in_axes=0, ...)</code></a>:
Takes a function <code class="language-plaintext highlighter-rouge">fun</code> and returns a <em>batched</em> version of that function by vectorizing each input along the axis specified by <code class="language-plaintext highlighter-rouge">in_axes</code>.
<code class="language-plaintext highlighter-rouge">vmap</code> is short for vectorized map; if you’re familiar with <code class="language-plaintext highlighter-rouge">map</code> in other programming languages, <code class="language-plaintext highlighter-rouge">vmap</code> is a similar idea.</p>
<p>Under the hood <code class="language-plaintext highlighter-rouge">vmap</code> does basically the same thing that you would if you were vectorizing something by hand.
For instance, suppose we had a PyTorch function <code class="language-plaintext highlighter-rouge">def f(a, b): return torch.mm(a, b)</code> that applies to matrices, but we are given a <em>batch</em> of matrices at a time.
We could compute the answer with a <code class="language-plaintext highlighter-rouge">for</code> loop, but it would be slow.
Instead we can look up the batched PyTorch function which computes batched matrix multiply (it’s <code class="language-plaintext highlighter-rouge">bmm</code>), and define the function <code class="language-plaintext highlighter-rouge">def vf(a, b): return torch.bmm(a, b)</code>.
We have transformed <code class="language-plaintext highlighter-rouge">f</code> by hand into a vectorized version, <code class="language-plaintext highlighter-rouge">vf</code>.</p>
<p>In JAX we can do the same thing automatically using <code class="language-plaintext highlighter-rouge">vmap</code>.
If we had a function <code class="language-plaintext highlighter-rouge">def f(a, b): return jnp.matmul(a, b)</code>, we could simply do <code class="language-plaintext highlighter-rouge">v = jax.vmap(f)</code>.
Crucially, this doesn’t just work for primitive functions.
You can call <code class="language-plaintext highlighter-rouge">vmap</code> on functions that are almost arbitrarily complicated, including functions that include <code class="language-plaintext highlighter-rouge">jax.grad</code>.</p>
<p>One subtlety here is in the use of <code class="language-plaintext highlighter-rouge">in_axes</code>.
Say that instead of taking a batch of <code class="language-plaintext highlighter-rouge">a</code> and a batch of <code class="language-plaintext highlighter-rouge">b</code>, we wanted a version of our function <code class="language-plaintext highlighter-rouge">f</code> that takes a batch of <code class="language-plaintext highlighter-rouge">a</code>, but only a <em>single</em> <code class="language-plaintext highlighter-rouge">b</code>, and gives us back <code class="language-plaintext highlighter-rouge">jnp.matmul(a[i], b)</code> for each <code class="language-plaintext highlighter-rouge">a[i]</code>.
We can define this new function <code class="language-plaintext highlighter-rouge">v0</code>, which is vectorized only with respect to argument 0, with the following call: <code class="language-plaintext highlighter-rouge">v0 = jax.vmap(f, in_axes=(0, None))</code>.
This asks that argument 0 of <code class="language-plaintext highlighter-rouge">f</code> be vectorized with respect to axis 0, and argument 1 not be vectorized at all.
Our result will be the same as if we iterated over the first dimension of <code class="language-plaintext highlighter-rouge">a</code>, and used the same value of <code class="language-plaintext highlighter-rouge">b</code> each time.</p>
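<p>Here’s a tiny runnable version of both cases:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import jax
import jax.numpy as jnp

def f(a, b):
    return jnp.matmul(a, b)

a = jnp.ones((10, 3, 4))   # a batch of ten 3x4 matrices
b = jnp.ones((10, 4, 5))   # a batch of ten 4x5 matrices

v = jax.vmap(f)            # vectorize over axis 0 of every argument
print(v(a, b).shape)       # -> (10, 3, 5)

single_b = jnp.ones((4, 5))
v0 = jax.vmap(f, in_axes=(0, None))   # batch over a only; reuse b
print(v0(a, single_b).shape)          # -> (10, 3, 5)
</code></pre></div></div>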
<h3 id="a-first-draft-of-parallel-network-training-with-vmap">A first draft of parallel network training with <code class="language-plaintext highlighter-rouge">vmap</code></h3>
<p>Now that we have the basics of JAX, we can start implementing a parallel training scheme.
The basic idea is simple: we will write a function that creates a neural network, and a function that updates that network, and then we’ll call <code class="language-plaintext highlighter-rouge">vmap</code> on them.
For full code please refer to the <a href="https://colab.research.google.com/drive/1-hVEZ8jck2nzIqmRgSmjQvxJp1wO2HI5?usp=sharing">colab</a> that accompanies this post.</p>
<p>I’ve defined a simple classification dataset: two spirals in 2D.
We can control the amount of noise in the data and how tight the spiral is.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">make_spirals</span><span class="p">(</span><span class="n">n_samples</span><span class="p">,</span> <span class="n">noise_std</span><span class="o">=</span><span class="mf">0.</span><span class="p">,</span> <span class="n">rotations</span><span class="o">=</span><span class="mf">1.</span><span class="p">):</span>
<span class="n">ts</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">n_samples</span><span class="p">)</span>
<span class="n">rs</span> <span class="o">=</span> <span class="n">ts</span> <span class="o">**</span> <span class="mf">0.5</span>
<span class="n">thetas</span> <span class="o">=</span> <span class="n">rs</span> <span class="o">*</span> <span class="n">rotations</span> <span class="o">*</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">np</span><span class="p">.</span><span class="n">pi</span>
<span class="n">signs</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="p">(</span><span class="n">n_samples</span><span class="p">,))</span> <span class="o">*</span> <span class="mi">2</span> <span class="o">-</span> <span class="mi">1</span>
<span class="n">labels</span> <span class="o">=</span> <span class="p">(</span><span class="n">signs</span> <span class="o">></span> <span class="mi">0</span><span class="p">).</span><span class="n">astype</span><span class="p">(</span><span class="nb">int</span><span class="p">)</span>
<span class="n">xs</span> <span class="o">=</span> <span class="n">rs</span> <span class="o">*</span> <span class="n">signs</span> <span class="o">*</span> <span class="n">jnp</span><span class="p">.</span><span class="n">cos</span><span class="p">(</span><span class="n">thetas</span><span class="p">)</span> <span class="o">+</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="n">n_samples</span><span class="p">)</span> <span class="o">*</span> <span class="n">noise_std</span>
<span class="n">ys</span> <span class="o">=</span> <span class="n">rs</span> <span class="o">*</span> <span class="n">signs</span> <span class="o">*</span> <span class="n">jnp</span><span class="p">.</span><span class="n">sin</span><span class="p">(</span><span class="n">thetas</span><span class="p">)</span> <span class="o">+</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="n">n_samples</span><span class="p">)</span> <span class="o">*</span> <span class="n">noise_std</span>
<span class="n">points</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">stack</span><span class="p">([</span><span class="n">xs</span><span class="p">,</span> <span class="n">ys</span><span class="p">],</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="k">return</span> <span class="n">points</span><span class="p">,</span> <span class="n">labels</span>
<span class="n">points</span><span class="p">,</span> <span class="n">labels</span> <span class="o">=</span> <span class="n">make_spirals</span><span class="p">(</span><span class="mi">100</span><span class="p">,</span> <span class="n">noise_std</span><span class="o">=</span><span class="mf">0.05</span><span class="p">)</span>
</code></pre></div></div>
<div id="two_spirals_chart" class="chart"></div>
<script src="/assets/js/two_spirals_spec.js"></script>
<script>
var embedOpt = {"mode": "vega-lite"};
vegaEmbed("#two_spirals_chart", spec, embedOpt);
</script>
<p>For our neural network, we can create a simple MLP classifier in <a href="https://github.com/google/flax">Flax</a>, a neural network library built on top of JAX:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">MLPClassifier</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
<span class="n">hidden_layers</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span>
<span class="n">hidden_dim</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">32</span>
<span class="n">n_classes</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">2</span>
<span class="o">@</span><span class="n">nn</span><span class="p">.</span><span class="n">compact</span>
<span class="k">def</span> <span class="nf">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
<span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">hidden_layers</span><span class="p">):</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Dense</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">hidden_dim</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">relu</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Dense</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">n_classes</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">log_softmax</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="k">return</span> <span class="n">x</span>
</code></pre></div></div>
<p>Because JAX is a <a href="https://en.wikipedia.org/wiki/Functional_programming">functional language</a>, we will carry around the <em>state</em> of a network separately from the functions that update or use that state.
Neural network libraries in JAX are still something of a work in progress and the abstractions aren’t terribly intuitive yet.
Somewhat confusingly, instantiating a Flax <code class="language-plaintext highlighter-rouge">nn.Module</code> returns an object with some automatically-generated functions, not a neural network state.
A full description of how this works is available in <a href="https://flax.readthedocs.io/en/latest/notebooks/flax_basics.html#Defining-your-own-models">the Flax docs</a>, but for now you can gloss over anything to do with Flax.</p>
<p>We can instantiate our model functions, then define a couple of helper functions for evaluating our loss.
Here <code class="language-plaintext highlighter-rouge">value_and_grad</code> is a JAX function that acts like <code class="language-plaintext highlighter-rouge">jax.grad</code>, except that the function it returns will produce both $f(x)$ and $\nabla_x f(x)$.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">classifier_fns</span> <span class="o">=</span> <span class="n">MLPClassifier</span><span class="p">()</span>
<span class="k">def</span> <span class="nf">cross_entropy</span><span class="p">(</span><span class="n">logprobs</span><span class="p">,</span> <span class="n">labels</span><span class="p">):</span>
<span class="n">one_hot_labels</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">one_hot</span><span class="p">(</span><span class="n">labels</span><span class="p">,</span> <span class="n">logprobs</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
<span class="k">return</span> <span class="o">-</span><span class="n">jnp</span><span class="p">.</span><span class="n">mean</span><span class="p">(</span><span class="n">jnp</span><span class="p">.</span><span class="nb">sum</span><span class="p">(</span><span class="n">one_hot_labels</span> <span class="o">*</span> <span class="n">logprobs</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">))</span>
<span class="k">def</span> <span class="nf">loss_fn</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">batch</span><span class="p">):</span>
<span class="n">logits</span> <span class="o">=</span> <span class="n">classifier_fns</span><span class="p">.</span><span class="nb">apply</span><span class="p">({</span><span class="s">'params'</span><span class="p">:</span> <span class="n">params</span><span class="p">},</span> <span class="n">batch</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
<span class="n">loss</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">mean</span><span class="p">(</span><span class="n">cross_entropy</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="n">batch</span><span class="p">[</span><span class="mi">1</span><span class="p">]))</span>
<span class="k">return</span> <span class="n">loss</span>
<span class="n">loss_and_grad_fn</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">value_and_grad</span><span class="p">(</span><span class="n">loss_fn</span><span class="p">)</span>
</code></pre></div></div>
<p>We’re ready now to create functions to make, train, and evaluate neural networks.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">init_fn</span><span class="p">(</span><span class="n">input_shape</span><span class="p">,</span> <span class="n">seed</span><span class="p">):</span>
<span class="n">rng</span> <span class="o">=</span> <span class="n">jr</span><span class="p">.</span><span class="n">PRNGKey</span><span class="p">(</span><span class="n">seed</span><span class="p">)</span> <span class="c1"># jr = jax.random
</span> <span class="n">dummy_input</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="o">*</span><span class="n">input_shape</span><span class="p">))</span>
<span class="n">params</span> <span class="o">=</span> <span class="n">classifier_fns</span><span class="p">.</span><span class="n">init</span><span class="p">(</span><span class="n">rng</span><span class="p">,</span> <span class="n">dummy_input</span><span class="p">)[</span><span class="s">'params'</span><span class="p">]</span> <span class="c1"># do shape inference
</span> <span class="n">optimizer_def</span> <span class="o">=</span> <span class="n">optim</span><span class="p">.</span><span class="n">Adam</span><span class="p">(</span><span class="n">learning_rate</span><span class="o">=</span><span class="mf">1e-3</span><span class="p">)</span>
<span class="n">optimizer</span> <span class="o">=</span> <span class="n">optimizer_def</span><span class="p">.</span><span class="n">create</span><span class="p">(</span><span class="n">params</span><span class="p">)</span>
<span class="k">return</span> <span class="n">optimizer</span>
<span class="o">@</span><span class="n">jax</span><span class="p">.</span><span class="n">jit</span> <span class="c1"># jit makes it go brrr
</span><span class="k">def</span> <span class="nf">train_step_fn</span><span class="p">(</span><span class="n">optimizer</span><span class="p">,</span> <span class="n">batch</span><span class="p">):</span>
<span class="n">loss</span> <span class="o">=</span> <span class="n">loss_fn</span><span class="p">(</span><span class="n">optimizer</span><span class="p">.</span><span class="n">target</span><span class="p">,</span> <span class="n">batch</span><span class="p">)</span>
<span class="n">loss</span><span class="p">,</span> <span class="n">grad</span> <span class="o">=</span> <span class="n">loss_and_grad_fn</span><span class="p">(</span><span class="n">optimizer</span><span class="p">.</span><span class="n">target</span><span class="p">,</span> <span class="n">batch</span><span class="p">)</span>
<span class="n">optimizer</span> <span class="o">=</span> <span class="n">optimizer</span><span class="p">.</span><span class="n">apply_gradient</span><span class="p">(</span><span class="n">grad</span><span class="p">)</span>
<span class="k">return</span> <span class="n">optimizer</span><span class="p">,</span> <span class="n">loss</span>
<span class="o">@</span><span class="n">jax</span><span class="p">.</span><span class="n">jit</span> <span class="c1"># jit makes it go brrr
</span><span class="k">def</span> <span class="nf">predict_fn</span><span class="p">(</span><span class="n">optimizer</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">array</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="k">return</span> <span class="n">classifier_fns</span><span class="p">.</span><span class="nb">apply</span><span class="p">({</span><span class="s">'params'</span><span class="p">:</span> <span class="n">optimizer</span><span class="p">.</span><span class="n">target</span><span class="p">},</span> <span class="n">x</span><span class="p">)</span>
</code></pre></div></div>
<p>This provides the entire API we’ll use for interacting with a neural network.
To see how to use it, let’s train a network to solve the spirals!
Here we’re using the entire dataset of <code class="language-plaintext highlighter-rouge">(points, labels)</code> as one batch.
Later we’ll deal with proper data handling.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">model_state</span> <span class="o">=</span> <span class="n">init_fn</span><span class="p">(</span><span class="n">input_shape</span><span class="o">=</span><span class="p">(</span><span class="mi">2</span><span class="p">,),</span> <span class="n">seed</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">100</span><span class="p">):</span>
<span class="n">model_state</span><span class="p">,</span> <span class="n">loss</span> <span class="o">=</span> <span class="n">train_step_fn</span><span class="p">(</span><span class="n">model_state</span><span class="p">,</span> <span class="p">(</span><span class="n">points</span><span class="p">,</span> <span class="n">labels</span><span class="p">))</span>
<span class="k">print</span><span class="p">(</span><span class="n">loss</span><span class="p">)</span>
<span class="o">-></span> <span class="mf">0.011945118</span>
</code></pre></div></div>
<div id="mlp_pred_chart" class="chart"></div>
<script src="/assets/js/mlp_pred_spec.js"></script>
<script>
var embedOpt = {"mode": "vega-lite"};
vegaEmbed("#mlp_pred_chart", spec, embedOpt);
</script>
<p>Unsurprisingly, this works pretty well.</p>
<p>If our hypothesis is correct, we can plug these functions into <code class="language-plaintext highlighter-rouge">vmap</code> and be embarrassingly parallel!
We will use <code class="language-plaintext highlighter-rouge">in_axes</code> with <code class="language-plaintext highlighter-rouge">vmap</code> to control which axes we parallelize over.
In our <code class="language-plaintext highlighter-rouge">init_fn</code>, which takes the input shape as an argument, we want to use the same input shape for every network, so we set the corresponding element of <code class="language-plaintext highlighter-rouge">in_axes</code> to <code class="language-plaintext highlighter-rouge">None</code>.
For now all of our networks will update on the same batch at each step, so in defining <code class="language-plaintext highlighter-rouge">parallel_train_step_fn</code> we parallelize over the model state, but not over the batch of data: <code class="language-plaintext highlighter-rouge">in_axes=(0, None)</code>.
The number of random seeds we feed in to <code class="language-plaintext highlighter-rouge">parallel_init_fn</code> will determine how many networks we train, in this case 10.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">parallel_init_fn</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">vmap</span><span class="p">(</span><span class="n">init_fn</span><span class="p">,</span> <span class="n">in_axes</span><span class="o">=</span><span class="p">(</span><span class="bp">None</span><span class="p">,</span> <span class="mi">0</span><span class="p">))</span>
<span class="n">parallel_train_step_fn</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">vmap</span><span class="p">(</span><span class="n">train_step_fn</span><span class="p">,</span> <span class="n">in_axes</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="bp">None</span><span class="p">))</span>
<span class="n">K</span> <span class="o">=</span> <span class="mi">10</span>
<span class="n">seeds</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">K</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span> <span class="n">K</span><span class="p">)</span>
<span class="n">model_states</span> <span class="o">=</span> <span class="n">parallel_init_fn</span><span class="p">((</span><span class="mi">2</span><span class="p">,),</span> <span class="n">seeds</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">100</span><span class="p">):</span>
<span class="n">model_states</span><span class="p">,</span> <span class="n">losses</span> <span class="o">=</span> <span class="n">parallel_train_step_fn</span><span class="p">(</span><span class="n">model_states</span><span class="p">,</span> <span class="p">(</span><span class="n">points</span><span class="p">,</span> <span class="n">labels</span><span class="p">))</span>
<span class="k">print</span><span class="p">(</span><span class="n">losses</span><span class="p">)</span>
<span class="o">-></span> <span class="p">[</span><span class="mf">0.01194512</span> <span class="mf">0.01250279</span> <span class="mf">0.01615315</span> <span class="mf">0.01403342</span> <span class="mf">0.01800855</span> <span class="mf">0.01515956</span>
<span class="mf">0.00658712</span> <span class="mf">0.00957206</span> <span class="mf">0.00750575</span> <span class="mf">0.00901282</span><span class="p">]</span>
</code></pre></div></div>
<p>Great! We can train 10 networks all at once with the same code that used to train one!
Furthermore, this code takes <em>almost exactly</em> the same amount of time to run as when we were only training one network; I get roughly a second for each run, with the parallel version being faster than the single run as often as not.<sup id="fnref:jit_time" role="doc-noteref"><a href="#fn:jit_time" class="footnote" rel="footnote">4</a></sup>
We can also bump <code class="language-plaintext highlighter-rouge">K</code> up higher to see what happens.
On my Titan X, training 100 networks still takes the same amount of time as training only one!
This basically delivers our missing 100x speedup:</p>
<blockquote>
<p>We can’t train <em>one</em> MLP 40,000x as fast as a ResNet, but we can train <em>100</em> MLPs 400x as fast as a ResNet.</p>
</blockquote>
<!-- One other interesting thing to note is that the loss for the solo network and for the first network in the batch is the same (up to machine precision) because they were initialized with the same seed. -->
<p>If you’re just here for speed, you’re done!
By passing a different seed to initialize each network with, this procedure will train many different networks at once in a way that’s useful for evaluating things like robustness to parameter initialization.
If you’d like to do a hyperparameter sweep, you can use different hyperparameters for each network in the network intitialization.
But if you’re interested in getting more out of this technique, in the next section I describe how to use parallel training to learn a bootstrapped ensemble for improved uncertainty calibration.</p>
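<p>Here’s what a learning-rate sweep could look like. This is a sketch rather than code from my experiments: it leans on <code class="language-plaintext highlighter-rouge">flax.optim</code>’s <code class="language-plaintext highlighter-rouge">apply_gradient</code> accepting per-call hyperparameter overrides, and the <code class="language-plaintext highlighter-rouge">*_with_lr_fn</code> names are made up for illustration.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>def train_step_with_lr_fn(optimizer, batch, learning_rate):
    loss, grad = loss_and_grad_fn(optimizer.target, batch)
    # override the learning rate stored in the optimizer for this step,
    # so each network takes its step with its own learning rate
    optimizer = optimizer.apply_gradient(grad, learning_rate=learning_rate)
    return optimizer, loss

# vmap over the model states and learning rates, share the batch
sweep_train_step_fn = jax.jit(jax.vmap(train_step_with_lr_fn, in_axes=(0, None, 0)))

learning_rates = jnp.array([1e-4, 3e-4, 1e-3, 3e-3, 1e-2])
model_states = parallel_init_fn((2,), jnp.arange(len(learning_rates)))
for i in range(100):
    model_states, losses = sweep_train_step_fn(
        model_states, (points, labels), learning_rates)
</code></pre></div></div>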
<h2 id="bootstrapped-ensembles">Bootstrapped ensembles</h2>
<p>If we plot the predictions for two of the networks we just trained, we see something interesting:</p>
<div id="multi_mlp_chart" class="chart"></div>
<script src="/assets/js/multi_mlp_spec.js"></script>
<script>
var embedOpt = {"mode": "vega-lite"};
vegaEmbed("#multi_mlp_chart", spec, embedOpt);
</script>
<p>Try to spot the differences in these predictions.
Super hard, right?
Even though we’ve trained two different neural networks, because they were trained on the same data, their predictions are almost identical.
In other words, they’re both overfit to the same sample of data.</p>
<p>In this section we’ll see an application of parallel training to learning <em>bootstrapped ensembles</em> on a dataset.
<a href="https://en.wikipedia.org/wiki/Bootstrapping_%28statistics%29">Bootstrapping</a> is a way of measuring the uncertainty that comes from only having a small random sample from an underlying data distribution.
Instead of training multiple networks on exactly the same data, we’ll resample a “new” dataset for each network by sampling with replacement from the empirical distribution (AKA the training set).</p>
<p>Here’s how this works.
Say we have 100 points in our dataset.
For the first network, we create a new dataset by drawing 100 samples uniformly at random from our existing dataset.
We then train the network only on this newly-sampled dataset.
Then for the next network, we again draw 100 samples from the real training set, and we train our network #2 on dataset #2, and so forth for however many we want.
<!-- To make things reproducible, we use 0, the network's index, as a random seed. -->
<!-- For the next network, we again draw 100 samples with replacement. -->
In this way, each network is trained on a dataset that’s slightly different; it may have many copies of some points, and no copies of others.</p>
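<p>In code, the obvious way to materialize one such resample would be something like the sketch below (a hypothetical helper; we’ll replace it with something much more memory-friendly in a moment):</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>def naive_bootstrap_resample(rng, data_x, data_y):
    n = data_x.shape[0]
    # n draws with replacement from the n original points
    indices = jr.randint(rng, shape=(n,), minval=0, maxval=n)
    return data_x[indices], data_y[indices]
</code></pre></div></div>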
<p>Bootstrapping has the nice property that the randomness in the resample decreases as the dataset grows larger, capturing the way a model’s uncertainty should decrease as the training set more fully describes the data distribution.
For small datasets, having 0 versus 1 sample of a particular point in the training set will make a huge difference in the learned model.
For large datasets, there’s probably another point right next to that one anyway, so it doesn’t matter too much if you leave out (or double) any one point.</p>
<h3 id="a-bootstrapped-data-sampler">A bootstrapped data sampler</h3>
<p>The crucial change from our previous setup is that, on every training step, each network is going to get a different batch of data.
To implement this, we will write a bootstrapped sampler which we can <code class="language-plaintext highlighter-rouge">vmap</code>.
Each time we call the vmapped sampler, we’ll get back a <em>batch of batches</em> of shape <code class="language-plaintext highlighter-rouge">(number_of_networks, batch_size, data_size)</code>.
The layout will basically be</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>[minibatch from dataset 0,
minibatch from dataset 1,
...,
minibatch from dataset K-1]
</code></pre></div></div>
<p>where dataset 0 is the first bootstrapped resample we took from our original dataset.
Ideally we want to implement this function without duplicating the entire dataset $K$ times in memory.</p>
<p>To do this, we will stop thinking directly about indices in the dataset, and start thinking about random seeds.
Imagine we want to sample one point uniformly at random from a dataset.
We could use a call like</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">point_index</span> <span class="o">=</span> <span class="n">jr</span><span class="p">.</span><span class="n">randint</span><span class="p">(</span><span class="n">jr</span><span class="p">.</span><span class="n">PRNGKey</span><span class="p">(</span><span class="n">point_seed</span><span class="p">),</span> <span class="n">shape</span><span class="o">=</span><span class="p">(),</span> <span class="c1"># jr = jax.random
</span> <span class="n">minval</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">maxval</span><span class="o">=</span><span class="n">dataset_size</span><span class="p">)</span>
</code></pre></div></div>
<p>to choose an index in our dataset to return.
This is a pure function of the <code class="language-plaintext highlighter-rouge">point_seed</code> that we use.
In mathematical terms, it is a <a href="https://en.wikipedia.org/wiki/Surjective_function">surjection</a> of the larger space of all possible random seeds onto the smaller space of integers in the interval [0, <code class="language-plaintext highlighter-rouge">dataset_size</code> - 1].
Visually it looks something like this:</p>
<p><img src="assets/img/surjection.png" alt="Depiction of a surjection" /></p>
<p>This has a super useful property for us: since every individual seed corresponds to an index sampled uniformly at random, any <em>set</em> of $N$ seeds corresponds to a <em>bootstrapped sample</em> of indices of size $N$!
For example, if we were to take seeds 194-198, they would correspond to exactly the kind of bootstrapped resampling of our size-5 dataset that we want.
We would have 2 copies of data points 0 and 3, no copies of points 1 or 2, and one copy of point 4.</p>
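<p>You can check this mapping for yourself in a couple of lines. (Which exact indices you get depends on your JAX version’s PRNG implementation, so treat the numbers in the figure as illustrative.)</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code># which dataset indices do seeds 194..198 select from a size-5 dataset?
indices = [int(jr.randint(jr.PRNGKey(s), shape=(), minval=0, maxval=5))
           for s in range(194, 199)]
print(indices)  # five draws with replacement from {0, ..., 4}
</code></pre></div></div>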
<p>All we need to do to generate a bootstrapped resample of a dataset of size $N$ is to pick some seed $i$ to start from and use all the seeds in $[i, i + N - 1]$ to sample indices.
What we need is a function which maps from a <code class="language-plaintext highlighter-rouge">dataset_index</code>, which tells us which bootstrapped dataset we’re sampling from, to the first random seed which will be included in our resample.
Since I’m lazy, I’ll use the hash function in <code class="language-plaintext highlighter-rouge">jax.random.split</code>.<sup id="fnref:jr_split" role="doc-noteref"><a href="#fn:jr_split" class="footnote" rel="footnote">5</a></sup></p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">get_first_seed</span><span class="p">(</span><span class="n">dataset_index</span><span class="p">):</span>
<span class="k">return</span> <span class="n">jr</span><span class="p">.</span><span class="n">split</span><span class="p">(</span><span class="n">jr</span><span class="p">.</span><span class="n">PRNGKey</span><span class="p">(</span><span class="n">dataset_index</span><span class="p">))[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">]</span>
<span class="n">get_first_seed</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="o">-></span> <span class="n">DeviceArray</span><span class="p">(</span><span class="mi">4146024105</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">uint32</span><span class="p">)</span>
</code></pre></div></div>
<p>Great! We got some giant random number.
This should be reassuring; the only way this could go wrong is if somehow we picked values of $i$ which were too close to each other.
This would make it so that e.g. dataset #1 had seeds $[i, i + n - 1]$, and dataset #2 had seeds $[i + 3, i + n + 2]$; since the seed ranges overlapped, the samples of those datasets would be too correlated.
The fact that we’re drawing from a really large space makes this astronomically unlikely.</p>
<p>Now we can use this function to implement a new function which will fetch us point $i$ from resample $k$:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">@</span><span class="n">jax</span><span class="p">.</span><span class="n">jit</span>
<span class="k">def</span> <span class="nf">get_example</span><span class="p">(</span><span class="n">data_x</span><span class="p">,</span> <span class="n">data_y</span><span class="p">,</span> <span class="n">dataset_index</span><span class="p">,</span> <span class="n">i</span><span class="p">):</span>
<span class="s">"""Gets example `i` from the resample with index `dataset_index`."""</span>
<span class="n">first_seed</span> <span class="o">=</span> <span class="n">get_first_seed</span><span class="p">(</span><span class="n">dataset_index</span><span class="p">)</span>
<span class="n">dataset_size</span> <span class="o">=</span> <span class="n">data_x</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="c1"># only use dataset_size distinct seeds
</span> <span class="c1"># this makes sure that our bootstrap-sampled dataset includes exactly
</span> <span class="c1"># `dataset_size` points.
</span> <span class="n">i</span> <span class="o">=</span> <span class="n">i</span> <span class="o">%</span> <span class="n">dataset_size</span>
<span class="n">point_seed</span> <span class="o">=</span> <span class="n">first_seed</span> <span class="o">+</span> <span class="n">i</span>
<span class="n">point_index</span> <span class="o">=</span> <span class="n">jr</span><span class="p">.</span><span class="n">randint</span><span class="p">(</span><span class="n">jr</span><span class="p">.</span><span class="n">PRNGKey</span><span class="p">(</span><span class="n">point_seed</span><span class="p">),</span> <span class="n">shape</span><span class="o">=</span><span class="p">(),</span>
<span class="n">minval</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">maxval</span><span class="o">=</span><span class="n">dataset_size</span><span class="p">)</span>
<span class="c1"># equivalent to x_i = data_x[point_index]
</span> <span class="n">x_i</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">lax</span><span class="p">.</span><span class="n">dynamic_index_in_dim</span><span class="p">(</span><span class="n">data_x</span><span class="p">,</span> <span class="n">point_index</span><span class="p">,</span>
<span class="n">keepdims</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
<span class="n">y_i</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">lax</span><span class="p">.</span><span class="n">dynamic_index_in_dim</span><span class="p">(</span><span class="n">data_y</span><span class="p">,</span> <span class="n">point_index</span><span class="p">,</span>
<span class="n">keepdims</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
<span class="k">return</span> <span class="n">x_i</span><span class="p">,</span> <span class="n">y_i</span>
</code></pre></div></div>
<p>And with this we can write a function which, given a dataset and a list of the bootstraps we want to sample from, gives us an iterator over batches-of-batches:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">bootstrap_multi_iterator</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="n">dataset_indices</span><span class="p">):</span>
<span class="s">"""Creates an iterator which, at each step, returns a batch of batches.
The kth batch is sampled from the bootstrapped resample of `dataset`
with index `dataset_indices[k]`."""</span>
<span class="n">batch_size</span> <span class="o">=</span> <span class="mi">32</span>
<span class="n">dataset_indices</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">array</span><span class="p">(</span><span class="n">dataset_indices</span><span class="p">)</span>
<span class="n">data_x</span><span class="p">,</span> <span class="n">data_y</span> <span class="o">=</span> <span class="n">dataset</span>
<span class="n">dataset_size</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">data_x</span><span class="p">)</span>
<span class="n">get_example_from_dataset</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">partial</span><span class="p">(</span><span class="n">get_example</span><span class="p">,</span> <span class="n">data_x</span><span class="p">,</span> <span class="n">data_y</span><span class="p">)</span>
<span class="c1"># for sampling a batch of data from one dataset
</span> <span class="n">get_batch</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">vmap</span><span class="p">(</span><span class="n">get_example_from_dataset</span><span class="p">,</span> <span class="n">in_axes</span><span class="o">=</span><span class="p">(</span><span class="bp">None</span><span class="p">,</span> <span class="mi">0</span><span class="p">))</span>
<span class="c1"># for sampling a batch of data from _each_ dataset
</span> <span class="n">get_multibatch</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">vmap</span><span class="p">(</span><span class="n">get_batch</span><span class="p">,</span> <span class="n">in_axes</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="bp">None</span><span class="p">))</span>
<span class="k">def</span> <span class="nf">iterate_multibatch</span><span class="p">():</span>
<span class="s">"""Construct an iterator which runs forever, at each step returning
a batch of batches."""</span>
<span class="n">i</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">while</span> <span class="bp">True</span><span class="p">:</span>
<span class="n">indices</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">arange</span><span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="n">i</span> <span class="o">+</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">jnp</span><span class="p">.</span><span class="n">int32</span><span class="p">)</span>
<span class="k">yield</span> <span class="n">get_multibatch</span><span class="p">(</span><span class="n">dataset_indices</span><span class="p">,</span> <span class="n">indices</span><span class="p">)</span>
<span class="n">i</span> <span class="o">+=</span> <span class="n">batch_size</span>
<span class="n">loader_iter</span> <span class="o">=</span> <span class="n">iterate_multibatch</span><span class="p">()</span>
<span class="k">return</span> <span class="n">loader_iter</span>
</code></pre></div></div>
<h3 id="training-the-bootstrapped-ensemble">Training the bootstrapped ensemble</h3>
<p>Thanks to the flexibility of the JAX API, switching from training $K$ networks on one batch of data to training $K$ networks on $K$ batches of data is super simple.
We can change one argument to our <code class="language-plaintext highlighter-rouge">vmap</code> of <code class="language-plaintext highlighter-rouge">train_step_fn</code>, construct our iterator, and we’re ready to go!</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># same as before
</span><span class="n">parallel_init_fn</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">vmap</span><span class="p">(</span><span class="n">init_fn</span><span class="p">,</span> <span class="n">in_axes</span><span class="o">=</span><span class="p">(</span><span class="bp">None</span><span class="p">,</span> <span class="mi">0</span><span class="p">))</span>
<span class="c1"># vmap over both inputs now: model state AND batch of data
</span><span class="n">bootstrap_train_step_fn</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">vmap</span><span class="p">(</span><span class="n">train_step_fn</span><span class="p">,</span> <span class="n">in_axes</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">))</span>
<span class="c1"># make seeds 0 to N-1, which we use for initializing the network and bootstrapping
</span><span class="n">N</span> <span class="o">=</span> <span class="mi">100</span>
<span class="n">seeds</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">N</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span> <span class="n">N</span><span class="p">).</span><span class="n">astype</span><span class="p">(</span><span class="n">jnp</span><span class="p">.</span><span class="n">int32</span><span class="p">)</span>
<span class="n">model_states</span> <span class="o">=</span> <span class="n">parallel_init_fn</span><span class="p">((</span><span class="mi">2</span><span class="p">,),</span> <span class="n">seeds</span><span class="p">)</span>
<span class="n">data_iterator</span> <span class="o">=</span> <span class="n">bootstrap_multi_iterator</span><span class="p">((</span><span class="n">points</span><span class="p">,</span> <span class="n">labels</span><span class="p">),</span> <span class="n">dataset_indices</span><span class="o">=</span><span class="n">seeds</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">100</span><span class="p">):</span>
<span class="n">x_batch</span><span class="p">,</span> <span class="n">y_batch</span> <span class="o">=</span> <span class="nb">next</span><span class="p">(</span><span class="n">data_iterator</span><span class="p">)</span>
<span class="n">model_states</span><span class="p">,</span> <span class="n">losses</span> <span class="o">=</span> <span class="n">bootstrap_train_step_fn</span><span class="p">(</span><span class="n">model_states</span><span class="p">,</span> <span class="p">(</span><span class="n">x_batch</span><span class="p">,</span> <span class="n">y_batch</span><span class="p">))</span>
<span class="k">print</span><span class="p">(</span><span class="n">losses</span><span class="p">)</span>
<span class="o">-></span> <span class="p">[</span><span class="mf">0.14846763</span> <span class="mf">0.09306543</span> <span class="mf">0.24074371</span> <span class="mf">0.26202717</span> <span class="mf">0.26234168</span> <span class="mf">0.18515839</span>
<span class="mf">0.10521372</span> <span class="mf">0.11991201</span> <span class="mf">0.1059431</span> <span class="mf">0.10932036</span><span class="p">]</span>
</code></pre></div></div>
<p>Naturally, this still takes only about a second.
Visualizing the predictions of a couple of the bootstrapped networks shows that they’re way more different than when we trained each one on the whole dataset.</p>
<div id="bootstrap_mlp_chart" class="chart"></div>
<script src="/assets/js/bootstrap_mlp_spec.js"></script>
<script>
var embedOpt = {"mode": "vega-lite"};
vegaEmbed("#bootstrap_mlp_chart", spec, embedOpt);
</script>
<p>I’ve plotted the entire dataset on top of the predictions, even though each network may not have seen all those points.
In particular, in the plot on the right it seems likely that the points which are currently misclassified were not in that network’s training set.</p>
<p>The exciting thing we can use these bootstrapped networks for is uncertainty quantification.
Since each network saw a different sample of the data, together their predictions can tell us what parts of the space are sure to be classified one way and which are more dependent on noise in your training set sample.
To do this, I simply average the probabilities I get from each network:</p>
\[p_\text{bootstrap}(y \mid x) = \frac{1}{K} \sum_{k=1}^K p(y \mid x; \theta_k)\]
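<p>In code this is one more <code class="language-plaintext highlighter-rouge">vmap</code>. A sketch, assuming the network outputs logits (if <code class="language-plaintext highlighter-rouge">predict_fn</code> already returns probabilities, drop the softmax); <code class="language-plaintext highlighter-rouge">ensemble_predict</code> is a name I’m making up here:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>parallel_predict_fn = jax.vmap(predict_fn, in_axes=(0, None))

def ensemble_predict(model_states, x):
    logits = parallel_predict_fn(model_states, x)  # (K, num_points, num_classes)
    probs = jax.nn.softmax(logits, axis=-1)
    return probs.mean(axis=0)  # average the probabilities over the K networks
</code></pre></div></div>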
<p>The results are really striking when we compare against the single network we trained on all the data before (shown at left):</p>
<div id="bootstrap_compare_chart" class="chart"></div>
<script src="/assets/js/bootstrap_compare_spec.js"></script>
<script>
var embedOpt = {"mode": "vega-lite"};
vegaEmbed("#bootstrap_compare_chart", spec, embedOpt);
</script>
<p>The single network predicts labels for almost the entire space with <em>absolute</em> confidence, even though it was only trained on 100 points.
By contrast, the bootstrapped ensemble (right) does a much better job of being uncertain near the boundary between the classes.</p>
<h2 id="conclusion">Conclusion</h2>
<p>Practically anytime you’re training a neural network, you would rather train several networks.
Whether you’re running multiple random seeds to make sure your results are reproducible, sweeping over learning rates to get the best results, or (as shown here) ensembling to improve calibration, there’s always <em>something</em> useful you could do with more runs.
By parallelizing training with JAX, you can run large numbers of small-scale experiments lightning fast.</p>
<p><strong>Citing</strong></p>
<p>If this blog post was useful to your research, you can cite it using</p>
<div class="language-bib highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="nc">@misc</span><span class="p">{</span><span class="nl">Whitney2021Parallelizing</span><span class="p">,</span>
<span class="na">author</span> <span class="p">=</span> <span class="s">{William F. Whitney}</span><span class="p">,</span>
<span class="na">title</span> <span class="p">=</span> <span class="s">{ {Parallelizing neural networks on one GPU with JAX} }</span><span class="p">,</span>
<span class="na">year</span> <span class="p">=</span> <span class="s">{2021}</span><span class="p">,</span>
<span class="na">url</span> <span class="p">=</span> <span class="s">{http://willwhitney.com/parallel-training-jax.html}</span><span class="p">,</span>
<span class="p">}</span>
</code></pre></div></div>
<p><strong>Acknowledgements</strong></p>
<p>Thanks to Tegan Maharaj and David Brandfonbrener for reading drafts of this article and providing helpful feedback.
The JAX community was instrumental in helping me figure all of this stuff out, especially <a href="https://twitter.com/SingularMattrix">Matt Johnson</a>, <a href="https://twitter.com/avitaloliver">Avital Oliver</a>, and <a href="https://twitter.com/anselmlevskaya">Anselm Levskaya</a>.
Thanks are also due to my co-authors on our <a href="https://arxiv.org/abs/2009.07368">representation evaluation paper</a>, including Min Jae Song, David Brandfonbrener (again), Jaan Altosaar, and my advisor Kyunghyun Cho.</p>
<hr />
<div class="footnotes" role="doc-endnotes">
<ol>
<li id="fn:flop_counter" role="doc-endnote">
<p>Thanks to the excellent <a href="https://github.com/sovrasov/flops-counter.pytorch">FLOPs counter</a> by <a href="https://github.com/sovrasov">sovrasov</a>.</p>
</li>
<li id="fn:keskar" role="doc-endnote">
<p>Keskar, N., Mudigere, D., Nocedal, J., Smelyanskiy, M., & Tang, P. (2017). On Large-Batch Training for Deep Learning: Generalization Gap and Sharp Minima. ArXiv, abs/1609.04836.</p>
</li>
<li id="fn:bottou" role="doc-endnote">
<p>Bottou, L., & Bousquet, O. (2007). The Tradeoffs of Large Scale Learning. Neural Information Processing Systems.</p>
</li>
<li id="fn:jit_time" role="doc-endnote">
<p>To accurately measure how long a compiled JAX function, like our <code class="language-plaintext highlighter-rouge">parallel_train_step_fn</code>, takes to run, we actually need to run it twice. On the first run for any set of input sizes it will spend a while compiling the function. On the second run we can measure how long the computation really takes. If you’re trying this at home, make sure to do a dry-run when you change <code class="language-plaintext highlighter-rouge">N</code> to see the real speed.</p>
</li>
<li id="fn:jr_split" role="doc-endnote">
<p>The function <code class="language-plaintext highlighter-rouge">jr.split</code>, or <code class="language-plaintext highlighter-rouge">jax.random.split</code>, takes an RNG key as an argument, and gives back a new one.</p>
</li>
</ol>
</div>Representation quality and the complexity of learning2020-09-24T00:00:00+00:002020-09-24T00:00:00+00:00http://willwhitney.com/representation-quality-and-the-complexity-of-learning<!-- <h1>Representation quality and the complexity of learning</h1> -->
<p><em>Cross-posted from the <a href="https://wp.nyu.edu/cilvr/2020/09/24/representation-quality-and-the-complexity-of-learning/">CILVR blog</a>.</em></p>
<p>In the last few years, there's been an explosion of work on learning good representations of data.
From NLP<sup><a href="#fn1-28786" id="fnr1-28786" title="see footnote" class="footnote">1</a></sup><sup><a href="#fn2-28786" id="fnr2-28786" title="see footnote" class="footnote">2</a></sup><sup><a href="#fn3-28786" id="fnr3-28786" title="see footnote" class="footnote">3</a></sup> to computer
vision<sup><a href="#fn4-28786" id="fnr4-28786" title="see footnote" class="footnote">4</a></sup><sup><a href="#fn5-28786" id="fnr5-28786" title="see footnote" class="footnote">5</a></sup><sup><a href="#fn6-28786" id="fnr6-28786" title="see footnote" class="footnote">6</a></sup> to reinforcement
learning<sup><a href="#fn7-28786" id="fnr7-28786" title="see footnote" class="footnote">7</a></sup><sup><a href="#fn8-28786" id="fnr8-28786" title="see footnote" class="footnote">8</a></sup><sup><a href="#fn9-28786" id="fnr9-28786" title="see footnote" class="footnote">9</a></sup>, the field has never
been hotter.
However, defining precisely what we mean by a good representation can be tricky.
This has led to a somewhat ad-hoc approach to evaluation in the literature, with each paper choosing its own
measure or set of measures and a general sense that our evaluation methods aren't very robust.</p>
<blockquote class="twitter-tweet" style="margin: 0 auto; display: block">
<p lang="en" dir="ltr">Though I'm not a big fan of the evaluation protocol: linear classification on top of
unsupervised features learned on ImageNet.</p>— Oriol Vinyals (@OriolVinyalsML) <a href="https://twitter.com/OriolVinyalsML/status/1228368026933719040?ref_src=twsrc%5Etfw">February 14,
2020</a>
</blockquote>
<script async="" src="https://platform.twitter.com/widgets.js" charset="utf-8"></script>
<p>In a recent paper, <a href="https://arxiv.org/abs/2009.07368">Evaluating representations by the complexity of
learning low-loss predictors</a><sup><a href="#fn10-28786" id="fnr10-28786" title="see footnote" class="footnote">10</a></sup>, we show that many notions of the quality of a representation for a task
can be expressed as a function of the <em>loss-data curve</em>.
This perspective allows us to see the limitations of existing measures and propose new ones that are more
robust.</p>
<p>We think that evaluation is crucially important in the field right now and we don't want the measures that we and
others have proposed to languish as purely theoretical exercises.
Since these measures (ours and others) aren't trivial to implement or to compute, we are releasing <a href="https://github.com/willwhitney/reprieve">a library called Reprieve</a> for representation evaluation
that aims to standardize the evaluation of representation quality.
Whether you're using the measures that we proposed or several others, and no matter what ML library you use, you
can evaluate representations with Reprieve.</p>
<script src="https://tarptaeya.github.io/repo-card/repo-card.js"></script>
<p>
<div class="repo-card" data-repo="willwhitney/reprieve" style="min-width: 300px; min-height: 115px; max-width: 600px; margin: 0 auto; background-image: url('/assets/img/reprieve_github.png'); background-position: center; background-size: contain; background-repeat: no-repeat">
</div>
</p>
<h2>Loss-data curves and existing measures</h2>
<p>The loss-data curve, with the size of the training set on the X axis and validation loss on the Y axis, describes
how an algorithm's performance varies based on the amount of training data it's given.
Intuitively, the curve for a representation that allows the algorithm to learn <em>efficiently</em> (with little
data) will lie to the left of the curve for a representation that makes learning less efficient.
Meanwhile a representation that contains more predictive information will lead to a curve that goes lower as the
training set size goes to infinity.</p>
<figure>
<img src="/assets/img/fig1.png" alt="Loss-data curves and representation quality measures. The red and blue curves are the result of using the same learning algorithm with two different representations of the data." />
<figcaption>Loss-data curves and representation quality measures. The red and blue curves are the result of
using the same learning algorithm with two different representations of the data.</figcaption>
</figure>
<p>On the loss-data curve we can graphically show the meaning of several existing evaluation measures for
representation quality (left panel).</p>
<p><strong>Validation accuracy</strong> with limited data (VA) is the simplest measure. VA corresponds to picking
some <span class="math">\(n\)</span> for the dataset size and looking only at a vertical slice of the loss-data
curve at that <span class="math">\(n\)</span>.</p>
<p><strong>Mutual information</strong> (MI) attempts to measure the quality of a representation by its mutual
information with the labels<sup><a href="#fn11-28786" id="fnr11-28786" title="see footnote" class="footnote">11</a></sup>. MI is equivalent to considering only the validation loss with infinite
training data.</p>
<p><strong>Minimum description length</strong> (MDL) is an interesting measure recently proposed by Voita et al.
(2020)<sup><a href="#fn12-28786" id="fnr12-28786" title="see footnote" class="footnote">12</a></sup>. Given a
fixed dataset, MDL measures the description length of the dataset's labels (the vector of all the Ys) given its
observations (the vector of all the Xs) according to a particular encoding scheme. In the <em>prequential</em>
or <em>online</em> coding scheme, a model is trained to predict <span class="math">\(p(Y^k \mid X^k)\)</span> on
a dataset of size <span class="math">\(k\)</span>, and then used to encode the <span class="math">\((k+1)^{\mathrm{th}}\)</span> point. MDL corresponds to the area under the loss-data curve up
to <span class="math">\(n\)</span>, the full size of the dataset.</p>
<p>An interesting feature of all these methods is that they depend on (or specify, for MI) a particular dataset
size.
This can be a bit tricky: how much data <em>should</em> an algorithm need to solve a new task?
Provide too little data and no representation will allow any learning, but provide too much and only asymptotic
loss will matter, not efficiency.</p>
<p>Instead, we will construct an evaluation procedure that measures a property of the <em>data distribution</em> and
the <em>learning algorithm</em>, not a particular dataset or dataset size.</p>
<h2>Surplus Description Length</h2>
<p>We're going to build on the MDL idea to make a measure of representation quality.
To do this, we measure the complexity of learning for a given data distribution and learning algorithm.
We have two main goals for this representation evaluation measure:</p>
<ol>
<li>It should measure a fundamental property of the data distribution and learning algorithm.</li>
<li>The measure shouldn't depend on a particular sample of a dataset from the data distribution, the size of the
dataset, or the order of the points.</li>
</ol>
<h3>Defining surplus description length</h3>
<p>To start with, imagine trying to efficiently encode a large number of samples of some random variable <span class="math">\(\mathbf{e}\)</span> which takes discrete values in <span class="math">\(\{1 \ldots
K\}\)</span> with probability <span class="math">\(p(\mathbf{e})\)</span>.
The best possible code for each sample leverages knowledge of the probability of observing that sample, and
assigns a code length of <span class="math">\(- \log p(e_i)\)</span> to each sampled value <span class="math">\(e_i\)</span>.
This results in an expected length per sample of
<span class="math">\[
\mathbb{E}_\mathbf{e} [\ell_p(\mathbf{e})] = \mathbb{E}_\mathbf{e} [- \log p(\mathbf{e})] = H(\mathbf{e})
\]</span>
where we use <span class="math">\(\ell_p\)</span> to denote the negative log-likelihood loss for the
distribution <span class="math">\(p\)</span>.
Intuitively, the entropy <span class="math">\(H(\mathbf{e})\)</span> represents the amount of randomness in <span class="math">\(\mathbf{e}\)</span>; if we know the outcome of
the event we need to encode ahead of time, <span class="math">\(H(\mathbf{e}) = 0\)</span> and we don't need to transmit anything at all.
</p>
<p>If instead <span class="math">\(\mathbf{e}\)</span> was encoded using some other distribution <span class="math">\(\hat p\)</span>, the expected length becomes <span class="math">\(H(\mathbf{e}) +
D_{\mathrm{KL}}(p~||~\hat p)\)</span>.
We call <span class="math">\(D_{\mathrm{KL}}(p~||~\hat p)\)</span> the <em>surplus description length</em> (SDL)
from encoding according to <span class="math">\(\hat p\)</span> instead of <span class="math">\(p\)</span>.<sup><a href="#fn13-9372" id="fnr13-9372" title="see footnote" class="footnote">13</a></sup>
We can also write it as
<span class="math">\[
\mathrm{SDL}(\hat p)
= D_{\mathrm{KL}}(p~||~\hat p)
= \mathbb{E}_{\mathbf{e} \sim p} \left[ \log p(\mathbf{e}) - \log \hat p(\mathbf{e}) \right]
\]</span>
to highlight how SDL measures only the extra entropy that comes from not having the correct model.</p>
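<p>As a quick worked example (numbers ours, not from the paper): suppose <span class="math">\(\mathbf{e}\)</span> is a fair coin flip, so <span class="math">\(H(\mathbf{e}) = 1\)</span> bit, but we encode it as if heads had probability 3/4. Then
<span class="math">\[
\mathrm{SDL}(\hat p) = \tfrac{1}{2} \log_2 \frac{1/2}{3/4} + \tfrac{1}{2} \log_2 \frac{1/2}{1/4} = 1 - \tfrac{1}{2} \log_2 3 \approx 0.21 \text{ bits},
\]</span>
so every flip costs about 1.21 bits to transmit instead of the optimal 1.</p>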
<h3>SDL as a measure of representation quality</h3>
<p>As our model learns we get a new <span class="math">\(\hat p\)</span> at every training step.
Similarly to MDL with online codes<sup><a href="#fn12-28786" title="see footnote" class="footnote">12</a></sup>,
we measure the SDL of the learned model at each step and then sum them up.
Writing the expected loss of running algorithm <span class="math">\(\mathcal{A}\)</span> on a dataset with <span class="math">\(i\)</span> points as <span class="math">\(L(\mathcal{A}_\phi, i)\)</span>, the SDL measure of
representation quality is
<span class="math">\[
m_{\mathrm{SDL}}(\phi, \mathcal{D}, \mathcal{A}) = \sum_{i=1}^\infty \Big[ L(\mathcal{A}_\phi, i) -
H(\mathbf{Y} \mid \mathbf{X}) \Big].
\]</span></p>
<p>We show in the paper that MDL is a special case of SDL which assumes that the true distribution of <span class="math">\(\mathbf{Y} \mid \mathbf{X}\)</span> is a delta mass. That is to say, <span class="math">\(H(\mathbf{Y} \mid \mathbf{X}) = 0\)</span> and the labels have no randomness at all.
This leads to some odd properties with real data, which typically has noise.
MDL goes to infinity with the size of the dataset even for algorithms which learn the true data distribution,
which makes numbers hard to compare.
More worryingly, if we rank the quality of two representations using MDL, that ranking can (and in practice
does) switch as we change the dataset size.
That means our conclusions about which representation is better are totally dependent on how much data we have
to evaluate them!</p>
<p>Since in practice we don't know the true entropy of the data distribution, we also propose a version of the SDL
measure where we set some threshold <span class="math">\(\varepsilon\)</span> as a criterion for success instead
of using the true entropy of the data.
As long as <span class="math">\(\varepsilon > H(\mathbf{Y} \mid \mathbf{X})\)</span>, this still has most of
the same nice properties.
A good way to set <span class="math">\(\varepsilon\)</span> would be to run the learning algorithm on a large
amount of data using the raw representation of the data, then set <span class="math">\(\varepsilon\)</span> to
the loss of that model plus a small slack term for estimation error.</p>
<p>We also propose a simpler measure called <span class="math">\(\varepsilon\)</span> sample complexity, or <span class="math">\(\varepsilon\)</span>SC, which is the number of training points required for the expected loss
to drop below <span class="math">\(\varepsilon\)</span>.
For full details on that <a href="https://arxiv.org/abs/2009.07368">check out the paper</a>! </p>
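<p>And a sketch of <span class="math">\(\varepsilon\)</span>SC over the same hypothetical sampled curve; returning <code>None</code> is one way to signal that there isn't enough data to evaluate the representation at this threshold:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>def epsilon_sample_complexity(dataset_sizes, losses, epsilon):
    # return the smallest sampled dataset size whose expected loss is below epsilon
    for n, loss in zip(dataset_sizes, losses):
        if loss >= epsilon:
            continue  # still above the threshold; this representation needs more data
        return n
    return None  # not enough data to evaluate the representation at this threshold
</code></pre></div></div>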
<h2>Representation evaluation in practice</h2>
<p>With our tools in hand, we can examine some practical representations.
Looking first at MNIST, we compare using the raw pixels to using neural encoders pretrained on supervised CIFAR
classification or trained without supervision as a low-dimensional VAE on MNIST.</p>
<figure>
<img src="/assets/img/mnist_results.png" alt="Results on MNIST. Since SDL measures a property of the data distribution, not a particular dataset, its values don't change as the dataset grows." />
<figcaption>Results on MNIST. Since SDL measures a property of the data distribution, not a particular dataset,
its values don't change as the dataset grows.</figcaption>
</figure>
<p>As you can see from the loss-data curve (right), these representations perform very differently!
While the VAE representation allows the quickest learning at first, it makes achieving very low loss hard.
Meanwhile the CIFAR pretrained representation supports learning that's more efficient than raw pixels for any
loss.</p>
<p>Looking at the evaluation measures, we see that the existing measures like validation loss and MDL tend to switch
their rankings when larger datasets are used for evaluation.
Meanwhile SDL and <span class="math">\(\varepsilon\)</span>SC know when there isn't enough data available to
evaluate a representation, and once they make a judgement, it sticks.</p>
<p>To show that this phenomenon isn't just limited to vision tasks or small datasets, we also provide experiments on
a part of speech classification task using pretrained representations from ELMo<sup><a href="#fn2-28786" title="see footnote" class="footnote">2</a></sup>.
Just like on MNIST, validation loss and MDL make very different predictions with small evaluation datasets than
with large ones.</p>
<figure>
<img src="/assets/img/elmo_results.png" alt="Results on part of speech classification." />
<figcaption>Results on part of speech classification.</figcaption>
</figure>
<h2>Better representation evaluation for everyone</h2>
<p>Existing measures of representation quality, which are functions of a particular dataset rather than the data
distribution, can have some tricky behavior.
Whether you use our measures or not, we urge our fellow members of the representation learning community to
think carefully about the measures and procedures that you use to evaluate representations.</p>
<p><a href="https://github.com/willwhitney/reprieve">Reprieve</a>, our library for representation evaluation, is one
tool that we think can help.
By using the powerful program transformations provided by <a href="https://github.com/google/jax">JAX</a>,
Reprieve is able to train the roughly 100 small networks required to construct a loss-data curve in parallel on
one GPU in about two minutes.
From there it can compute all of the measures that we mentioned today.</p>
<p>We hope that by standardizing on one codebase for evaluation, we in the representation learning community can
move faster while producing results that are more comparable and more reproducible.
If Reprieve is missing a measure that you think is important, submit a pull request!</p>
<div class="page-break" style="page-break-before: always;"></div>
<div class="footnotes">
<hr />
<ol>
<li id="fn1-28786">
<p>Jacob Devlin, Ming-Wei Chang, Kenton Lee, Kristina Toutanova. <a href="https://arxiv.org/abs/1810.04805">BERT: Pre-training of Deep Bidirectional Transformers
for Language Understanding</a>. <a href="#fnr1-28786" title="return to article" class="reversefootnote">↩︎</a></p>
</li>
<li id="fn2-28786">
<p>Matthew E. Peters, Mark Neumann, Mohit Iyyer, Matt Gardner, Christopher Clark, Kenton Lee, Luke
Zettlemoyer. <a href="https://arxiv.org/abs/1802.05365">Deep contextualized word
representations</a>. <a href="#fnr2-28786" title="return to article" class="reversefootnote">↩︎</a></p>
</li>
<li id="fn3-28786">
<p>Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke
Zettlemoyer, Veselin Stoyanov. <a href="https://arxiv.org/abs/1907.11692">RoBERTa: A Robustly
Optimized BERT Pretraining Approach</a>. <a href="#fnr3-28786" title="return to article" class="reversefootnote">↩︎</a></p>
</li>
<li id="fn4-28786">
<p>Ting Chen, Simon Kornblith, Mohammad Norouzi, Geoffrey Hinton. <a href="https://arxiv.org/abs/2002.05709">A Simple Framework for Contrastive Learning of Visual
Representations</a>. <a href="#fnr4-28786" title="return to article" class="reversefootnote">↩︎</a></p>
</li>
<li id="fn5-28786">
<p>Aaron van den Oord, Yazhe Li, Oriol Vinyals. <a href="https://arxiv.org/abs/1807.03748">Representation Learning with Contrastive Predictive
Coding</a>. <a href="#fnr5-28786" title="return to article" class="reversefootnote">↩︎</a></p>
</li>
<li id="fn6-28786">
<p>Kaiming He, Haoqi Fan, Yuxin Wu, Saining Xie, Ross Girshick. <a href="https://arxiv.org/abs/1911.05722">Momentum Contrast for Unsupervised Visual Representation
Learning</a>. <a href="#fnr6-28786" title="return to article" class="reversefootnote">↩︎</a></p>
</li>
<li id="fn7-28786">
<p>Aravind Srinivas, Michael Laskin, Pieter Abbeel. <a href="https://arxiv.org/abs/2004.04136">CURL:
Contrastive Unsupervised Representations for Reinforcement Learning</a>. <a href="#fnr7-28786" title="return to article" class="reversefootnote">↩︎</a></p>
</li>
<li id="fn8-28786">
<p>Carles Gelada, Saurabh Kumar, Jacob Buckman, Ofir Nachum, Marc G. Bellemare. <a href="https://arxiv.org/abs/1906.02736">DeepMDP: Learning Continuous Latent Space Models for
Representation Learning</a>. <a href="#fnr8-28786" title="return to article" class="reversefootnote">↩︎</a></p>
</li>
<li id="fn9-28786">
<p>Amy Zhang, Rowan McAllister, Roberto Calandra, Yarin Gal, Sergey Levine. <a href="https://arxiv.org/abs/2006.10742">Learning Invariant Representations for Reinforcement
Learning without Reconstruction</a>. <a href="#fnr9-28786" title="return to article" class="reversefootnote">↩︎</a></p>
</li>
<li id="fn10-28786">
<p>William F. Whitney, Min Jae Song, David Brandfonbrener, Jaan Altosaar, Kyunghyun Cho. <a href="https://arxiv.org/abs/2009.07368">Evaluating representations by the complexity of learning
low-loss predictors</a>. <a href="#fnr10-28786" title="return to article" class="reversefootnote">↩︎</a></p>
</li>
<li id="fn11-28786">
<p>Note that to actually measure the mutual information between the random variables of the
representation and data requires arbitrarily large models, infinite data, and unbounded computation.
Mutual information is not a nice quantity to compute with. <a href="#fnr11-28786" title="return to article" class="reversefootnote">↩︎</a></p>
</li>
<li id="fn12-28786">
<p>Elena Voita, Ivan Titov. <a href="https://arxiv.org/abs/2003.12298">Information-Theoretic Probing
with Minimum Description Length</a>. <a href="#fnr12-28786" title="return to article" class="reversefootnote">↩︎</a></p>
</li>
<li id="fn13-9372">
<p>Encoding using the wrong distribution means that some event which happens often must have gotten a long code, and
in exchange some uncommon event got a short code.
It's as if someone made up a new language that made "the" 8 letters long and "eggplant" only
3; it would be convenient once a week when you type "eggplant", but really annoying the 100 times a
day you type "the". <a href="#fnr13-9372" title="return to article" class="reversefootnote">↩︎</a></p>
</li>
</ol>
</div>Cross-posted from the CILVR blog. In the last few years, there's been an explosion of work on learning good representations of data. From NLP123 to computer vision456 to reinforcement learning789, the field has never been hotter. However, defining precisely what we mean by a good representation can be tricky. This has led to a somewhat ad-hoc approach to evaluation in the literature, with each paper choosing its own measure or set of measures and a general sense that our evaluation methods aren't very robust. Though I'm not a big fan of the evaluation protocol: linear classification on top of unsupervised features learned on ImageNet.— Oriol Vinyals (@OriolVinyalsML) February 14, 2020 In a recent paper, Evaluating representations by the complexity of learning low-loss predictors10, we show that many notions of the quality of a representation for a task can be expressed as a function of the loss-data curve. This perspective allows us to see the limitations of existing measures and propose new ones that are more robust. We think that evaluation is crucially important in the field right now and we don't want the measures that we and others have proposed to languish as purely theoretical exercises. Since these measures (ours and others) aren't trivial to implement or to compute, we are releasing a library called Reprieve for representation evaluation that aims to standardize the evaluation of representation quality. Whether you're using the measures that we proposed or several others, and no matter what ML library you use, you can evaluate representations with Reprieve. Loss-data curves and existing measures The loss-data curve, with the size of the training set on the X axis and validation loss on the Y axis, describes how an algorithm's performance varies based on the amount of training data it's given. Intuitively, the curve for a representation that allows the algorithm to learn efficiently (with little data) will lie to the left of the curve for a representation that makes learning less efficient. Meanwhile a representation that contains more predictive information will lead to a curve that goes lower as the training set size goes to infinity. Loss-data curves and representation quality measures. The red and blue curves are the result of using the same learning algorithm with two different representations of the data. On the loss-data curve we can graphically show the meaning of several existing evaluation measures for representation quality (left panel). Validation accuracy with limited data (VA) is the simplest measure. VA corresponds to picking some \(n\) for the dataset size and looking only at a vertical slice of the loss-data curve at that \(n\). Mutual information (MI) attempts to measure the quality of a representation by its mutual information with the labels11. MI is equivalent to considering only the validation loss with infinite training data. Minimum description length (MDL) is an interesting measure recently proposed by Voita et al. (2020)12. Given a fixed dataset, MDL measures the description length of the dataset's labels (the vector of all the Ys) given its observations (the vector of all the Xs) according to a particular encoding scheme. In the prequential or online coding scheme, a model is trained to predict \(p(Y^k \mid X^k)\) on a dataset of size \(k\), and then used to encode the \((k+1)^{\mathrm{th}}\) point. MDL corresponds to the area under the loss-data curve up to \(n\), the full size of the dataset. 
An interesting feature of all these methods is that they depend on (or specify, for MI) a particular dataset size. This can be a bit tricky: how much data should an algorithm need to solve a new task? Provide too little data and no representation will allow any learning, but provide too much and only asymptotic loss will matter, not efficiency. Instead, we will construct an evaluation procedure that measures a property of the data distribution and the learning algorithm, not a particular dataset or dataset size. Surplus Description Length We're going to build on the MDL idea to make a measure of representation quality. To do this, we measure the complexity of learning for a given data distribution and learning algorithm. We have two main goals for this representation evaluation measure: It should measure a fundamental property of the data distribution and learning algorithm. The measure shouldn't depend on a particular sample of a dataset from the data distribution, the size of the dataset, or the order of the points. Defining surplus description length To start with, imagine trying to efficiently encode a large number of samples of some random variable \(\mathbf{e}\) which takes discrete values in \(\{1 \ldots K\}\) with probability \(p(\mathbf{e})\). The best possible code for each sample leverages knowledge of the probability of observing that sample, and assigns a code length of \(- \log p(e_i)\) to each sampled value \(e_i\). This results in an expected length per sample of \[ \mathbb{E}_\mathbf{e} [\ell_p(\mathbf{e})] = \mathbb{E}_\mathbf{e} [- \log p(\mathbf{e})] = H(\mathbf{e}) \] where we use \(\ell_p\) to denote the negative log-likelihood loss for the distribution \(p\). Intuitively, the entropy \(H(\mathbf{e})\) represents the amount of randomness in \(\mathbf{e}\); if we know the outcome of the event we need to encode ahead of time, \(H(\mathbf{e}) = 0\) and we don't need to transmit anything at all. If instead \(\mathbf{e}\) was encoded using some other distribution \(\hat p\), the expected length becomes \(H(\mathbf{e}) + D_{\mathrm{KL}}(p~||~\hat p)\). We call \(D_{\mathrm{KL}}(p~||~\hat p)\) the surplus description length (SDL) from encoding according to \(\hat p\) instead of \(p\).13 We can also write it as \[ \mathrm{SDL}(\hat p) = D_{\mathrm{KL}}(p~||~\hat p) = \mathbb{E}_{\mathbf{e} \sim p} \left[ \log p(\mathbf{e}) - \log \hat p(\mathbf{e}) \right] \] to highlight how SDL measures only the extra entropy that comes from not having the correct model. SDL as a measure of representation quality As our model learns we get a new \(\hat p\) at every training step. Similarly to MDL with online codes12, we measure the SDL of the learned model at each step and then sum them up. Writing the expected loss of running algorithm \(\mathcal{A}\) on a dataset with \(i\) points as \(L(\mathcal{A}_\phi, i)\), the SDL measure of representation quality is \[ m_{\mathrm{SDL}}(\phi, \mathcal{D}, \mathcal{A}) = \sum_{i=1}^\infty \Big[ L(\mathcal{A}_\phi, i) - H(\mathbf{Y} \mid \mathbf{X}) \Big]. \] We show in the paper that MDL is a special case of SDL which assumes that the true distribution of \(\mathbf{Y} \mid \mathbf{X}\) is a delta mass. That is to say, \(H(\mathbf{Y} \mid \mathbf{X}) = 0\) and the labels have no randomness at all. This leads to some odd properties with real data, which typically has noise. MDL goes to infinity with the size of the dataset even for algorithms which learn the true data distribution, which makes numbers hard to compare. 
More worryingly, if we rank the quality of two representations using MDL, that ranking can (and in practice does) switch as we change the dataset size. That means our conclusions about which representation is better depend entirely on how much data we happened to use for evaluation!

Since in practice we don't know the true entropy of the data distribution, we also propose a version of the SDL measure that uses some threshold \(\varepsilon\) as the criterion for success instead of the true entropy. As long as \(\varepsilon > H(\mathbf{Y} \mid \mathbf{X})\), this version retains most of the same nice properties. A good way to set \(\varepsilon\) is to run the learning algorithm on a large amount of data using the raw representation, then set \(\varepsilon\) to that model's loss plus a small slack term for estimation error.

We also propose a simpler measure called \(\varepsilon\) sample complexity, or \(\varepsilon\)SC, which is the number of training points required for the expected loss to drop below \(\varepsilon\). For full details, check out the paper!

# Representation evaluation in practice

With our tools in hand, we can examine some practical representations. Looking first at MNIST, we compare using the raw pixels against neural encoders either pretrained on supervised CIFAR classification or trained without supervision as a low-dimensional VAE on MNIST.

*Figure: Results on MNIST. Since SDL measures a property of the data distribution, not a particular dataset, its values don't change as the dataset grows.*

As you can see from the loss-data curve (right), these representations perform very differently! While the VAE representation allows the quickest learning at first, it makes achieving very low loss hard. Meanwhile, the CIFAR-pretrained representation supports learning that is more efficient than raw pixels at every loss level. Looking at the evaluation measures, we see that existing measures like validation loss and MDL tend to switch their rankings when larger datasets are used for evaluation. Meanwhile, SDL and \(\varepsilon\)SC know when there isn't enough data available to evaluate a representation, and once they make a judgement, it sticks.

To show that this phenomenon isn't limited to vision tasks or small datasets, we also run experiments on a part-of-speech classification task using pretrained representations from ELMo [2]. Just as on MNIST, validation loss and MDL make very different predictions with small evaluation datasets than with large ones.

*Figure: Results on part-of-speech classification.*

# Better representation evaluation for everyone

Existing measures of representation quality, which are functions of a particular dataset rather than the data distribution, can behave in tricky ways. Whether you use our measures or not, we urge our fellow members of the representation learning community to think carefully about the measures and procedures you use to evaluate representations.

Reprieve, our library for representation evaluation, is one tool that we think can help. By using the powerful program transformations provided by JAX, Reprieve trains the ~100 small networks required to construct a loss-data curve in parallel on one GPU in about two minutes. From there it can compute all of the measures discussed in this post. We hope that by standardizing on one codebase for evaluation, the representation learning community can move faster while producing results that are more comparable and more reproducible.
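The parallel-training trick is easy to sketch with `jax.vmap`. The snippet below illustrates the idea only and is not Reprieve's actual API; the model, the masking scheme, and names like `init_params` and `train_step` are placeholders of mine.

```python
import jax
import jax.numpy as jnp

# Illustrative sketch: train one small network per cell of the loss-data
# curve (dataset size x seed), all in a single batched computation.

def init_params(rng):
    # A tiny two-layer MLP; the sizes are arbitrary placeholders.
    k1, k2 = jax.random.split(rng)
    return {
        "w1": 0.05 * jax.random.normal(k1, (784, 64)), "b1": jnp.zeros(64),
        "w2": 0.05 * jax.random.normal(k2, (64, 10)), "b2": jnp.zeros(10),
    }

def loss_fn(params, x, y, mask):
    # Masked cross-entropy: every parallel copy sees the same full batch,
    # but mask i has ones only for that copy's n_i training examples.
    h = jax.nn.relu(x @ params["w1"] + params["b1"])
    logits = h @ params["w2"] + params["b2"]
    nll = -jax.nn.log_softmax(logits)[jnp.arange(y.shape[0]), y]
    return jnp.sum(nll * mask) / jnp.sum(mask)

def train_step(params, x, y, mask, lr=1e-2):
    # One step of plain SGD on the masked loss.
    grads = jax.grad(loss_fn)(params, x, y, mask)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

# vmap over per-copy parameters and masks; the raw data is shared.
parallel_step = jax.jit(jax.vmap(train_step, in_axes=(0, None, None, 0)))

# Usage sketch: K copies, each with its own init and dataset-size mask.
# keys = jax.random.split(jax.random.PRNGKey(0), K)
# params = jax.vmap(init_params)(keys)          # stacked pytree, leading axis K
# params = parallel_step(params, x, y, masks)   # one SGD step for all K nets
```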
If Reprieve is missing a measure that you think is important, submit a pull request!

---

1. Jacob Devlin, Ming-Wei Chang, Kenton Lee, Kristina Toutanova. BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding.
2. Matthew E. Peters, Mark Neumann, Mohit Iyyer, Matt Gardner, Christopher Clark, Kenton Lee, Luke Zettlemoyer. Deep contextualized word representations.
3. Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov. RoBERTa: A Robustly Optimized BERT Pretraining Approach.
4. Ting Chen, Simon Kornblith, Mohammad Norouzi, Geoffrey Hinton. A Simple Framework for Contrastive Learning of Visual Representations.
5. Aaron van den Oord, Yazhe Li, Oriol Vinyals. Representation Learning with Contrastive Predictive Coding.
6. Kaiming He, Haoqi Fan, Yuxin Wu, Saining Xie, Ross Girshick. Momentum Contrast for Unsupervised Visual Representation Learning.
7. Aravind Srinivas, Michael Laskin, Pieter Abbeel. CURL: Contrastive Unsupervised Representations for Reinforcement Learning.
8. Carles Gelada, Saurabh Kumar, Jacob Buckman, Ofir Nachum, Marc G. Bellemare. DeepMDP: Learning Continuous Latent Space Models for Representation Learning.
9. Amy Zhang, Rowan McAllister, Roberto Calandra, Yarin Gal, Sergey Levine. Learning Invariant Representations for Reinforcement Learning without Reconstruction.
10. William F. Whitney, Min Jae Song, David Brandfonbrener, Jaan Altosaar, Kyunghyun Cho. Evaluating representations by the complexity of learning low-loss predictors.
11. Note that actually measuring the mutual information between the random variables of the representation and the data requires arbitrarily large models, infinite data, and unbounded computation. Mutual information is not a nice quantity to compute with.
12. Elena Voita, Ivan Titov. Information-Theoretic Probing with Minimum Description Length.
13. Encoding using the wrong distribution means that some event which happens often must have gotten a long code, and in exchange some uncommon event got a short code. It's as if someone made up a new language that made "the" 8 letters long and "eggplant" only 3; it would be convenient once a week when you type "eggplant", but really annoying the 100 times a day you type "the".