

Hi, JAX!

A short introduction to JAX for deep learning researchers.

This course runs from July 11, 2024 to September 12, 2024.

Anyone can join the course. Simply join the metauni Discord server (invite link) and introduce yourself in the #hijax channel. (Note: unlike other metauni events, this course will not use Roblox.)

Format

Weekly workshops: A 60-minute live code-along workshop on implementing a short ML-related project that requires understanding a new JAX concept. See the syllabus for details.

The code will be available on GitHub to help you follow each week’s workshop. Participants are invited to share their own solutions by maintaining a public fork of the course repository.

Optional homework: Participants are invited to complete additional projects outside of the workshops, including the challenge listed for each workshop in the syllabus and the tasks on the bounty board below.

These are suggestions only, and participants are encouraged to pursue their own alternative project ideas and share their work.

Community: Participants are welcome to use the #hijax channel in the metauni Discord server to discuss course topics, workshops, and challenges. This channel will also be used for scheduling updates.

Prerequisites

Programming:

Theory:

You will need a Python environment (but not necessarily a GPU) if you want to code along during the workshops.

Syllabus

The course will cover the following nine topics. Workshops run on Thursdays at 2pm AEST on the listed dates.

  1. Hi, JAX! How’s life? (July 11) Intro to JAX, course overview, immutability and jax.numpy, randomness with jax.random. Demonstration: Elementary cellular automaton. Challenge: Conway’s game of life. (See the immutability and randomness sketch after the syllabus.)

  2. Hi, automatic differentiation! (July 18) Call/init function API for models, jax.grad transformation. Demonstration: Classical perceptron trained with classical SGD, vanilla JAX. Challenge: Multi-layer perceptron, vanilla JAX. (See the jax.grad sketch after the syllabus.)

  3. Hi, pytrees! (July 25) Pytrees, jax.tree.map, equinox modules. Demonstration: Train an MLP on MNIST with minibatch SGD. Challenge: Manually register the MLP modules as pytrees (obviating the equinox dependency). (See the jax.tree.map sketch after the syllabus.)

  4. Hi, deep learning ecosystem! Hi, automatic vectorisation! (August 1) Modules with equinox.nn, vectorisation with jax.vmap, stateful optimisation with optax. Demonstration: Train a CNN on MNIST with minibatch SGD and Adam. Challenge: Implement a drop-in replacement for optax.adam. (See the jax.vmap and optax sketch after the syllabus.)

  5. Hi, just-in-time compilation! (August 8) Compilation with jax.jit, tracing vs. execution, side-effects. Demonstration: JIT dojo (part 1), and train an accelerated CNN on MNIST. Challenge: Implement and train more historical network architectures. (See the jax.jit sketch after the syllabus.)

  6. Hi again, just-in-time compilation! (August 15) Compile errors due to non-static shapes, static arguments and recompilation. Demonstration: Train a byte-transformer on the Sherlock Holmes canon. Challenge: Add dropout modules to the transformer. (See the static-arguments sketch after the syllabus.)

  7. Hi, loop acceleration! (August 22) Looping computations with jax.lax.scan. Demonstration: Accelerate a whole training loop. Challenge: Vectorise a hyperparameter sweep and replicate an error rate from Yann LeCun’s table. (See the jax.lax.scan sketch after the syllabus.)

Note: No workshop on August 29.

  8. Hi, branching computation! (September 5) Stateful environment API, conditional computation with jax.lax.cond and jax.lax.select. Demonstration: Simple gridworld maze environment. Challenge: Determine solvability by implementing an accelerated DFS/BFS. Alternative challenge: Tabular Q-learning. (See the jax.lax.cond sketch after the syllabus.)

  9. Hi, deep reinforcement learning! (September 12) PPO algorithm, reverse scan, revision of previous topics. Demonstration: Accelerated PPO, solve small mazes. Challenge: Solve larger mazes. (See the reverse-scan sketch after the syllabus.)
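
Workshop 1 introduces immutable arrays and explicit random state. Here is a minimal illustrative sketch (not course code); the array contents and shapes are arbitrary.

    import jax
    import jax.numpy as jnp

    # jax.numpy arrays are immutable: updates return a new array.
    x = jnp.zeros(8)
    y = x.at[3].set(1.0)                   # x is unchanged; y has a 1.0 at index 3

    # Randomness is explicit: keys are split before each use.
    key = jax.random.key(0)
    key, subkey = jax.random.split(key)
    noise = jax.random.normal(subkey, shape=(8,))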
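
Workshop 2 covers jax.grad. A minimal sketch of differentiating a loss with respect to parameters; the linear model and squared-error loss here are placeholders rather than the workshop’s perceptron.

    import jax
    import jax.numpy as jnp

    def predict(params, x):
        w, b = params
        return x @ w + b

    def loss(params, x, y):
        return jnp.mean((predict(params, x) - y) ** 2)

    params = (jnp.zeros(3), 0.0)
    x, y = jnp.ones((4, 3)), jnp.ones(4)

    grads = jax.grad(loss)(params, x, y)   # same (w, b) structure as params
    w, b = params
    gw, gb = grads
    params = (w - 0.1 * gw, b - 0.1 * gb)  # one "vanilla" SGD step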
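
Workshop 3 is about pytrees. A minimal sketch of writing an SGD update once with jax.tree.map and applying it to a nested parameter structure; the dict layout and sizes are illustrative only.

    import jax
    import jax.numpy as jnp

    # Any nesting of dicts, lists, and tuples of arrays is a pytree.
    params = {
        "layer1": {"w": jnp.ones((3, 4)), "b": jnp.zeros(4)},
        "layer2": {"w": jnp.ones((4, 1)), "b": jnp.zeros(1)},
    }
    grads = jax.tree.map(jnp.ones_like, params)   # stand-in for real gradients

    def sgd_step(params, grads, lr=0.01):
        # One update rule, mapped over every leaf array in the tree.
        return jax.tree.map(lambda p, g: p - lr * g, params, grads)

    params = sgd_step(params, grads)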
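
Workshop 4 brings in the wider ecosystem. A minimal sketch combining jax.vmap (vectorising a per-example loss over a batch axis) with an optax optimiser; the linear model stands in for the workshop’s CNN, and the hyperparameters are arbitrary.

    import jax
    import jax.numpy as jnp
    import optax

    def example_loss(params, x, y):            # written for a single example
        return (x @ params["w"] + params["b"] - y) ** 2

    def batch_loss(params, xs, ys):            # vectorised over the batch axis
        per_example = jax.vmap(example_loss, in_axes=(None, 0, 0))(params, xs, ys)
        return jnp.mean(per_example)

    params = {"w": jnp.zeros(3), "b": jnp.zeros(())}
    optimiser = optax.adam(learning_rate=1e-3)
    opt_state = optimiser.init(params)

    xs, ys = jnp.ones((16, 3)), jnp.ones(16)
    grads = jax.grad(batch_loss)(params, xs, ys)
    updates, opt_state = optimiser.update(grads, opt_state)
    params = optax.apply_updates(params, updates)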
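
Workshop 5 covers jax.jit. A minimal sketch showing that Python-level side effects such as print run while the function is traced, not on every call; the function itself is arbitrary.

    import jax
    import jax.numpy as jnp

    @jax.jit
    def f(x):
        print("tracing with:", x)   # side effect: runs only during tracing
        return jnp.sin(x) ** 2

    f(jnp.arange(3.0))   # prints once; the function is traced and compiled
    f(jnp.arange(3.0))   # no print; the cached compiled version is reused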
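
Workshop 6 looks at compile errors from non-static shapes. A minimal sketch of marking an argument as static so it can determine an output shape, at the cost of recompilation whenever it changes; the function is arbitrary.

    import jax
    import jax.numpy as jnp
    from functools import partial

    @partial(jax.jit, static_argnums=0)      # static_argnames works similarly
    def first_n_squares(n, x):
        # n must be static because it determines the output shape.
        return jnp.arange(n) ** 2 + x

    first_n_squares(4, 1.0)   # compiles for n=4
    first_n_squares(4, 2.0)   # reuses the n=4 compilation
    first_n_squares(5, 1.0)   # recompiles for n=5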
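
Workshop 7 covers jax.lax.scan. A minimal sketch of a scanned loop, with a running sum standing in for a training loop: the step function maps a carry and an input to a new carry and a per-step output.

    import jax
    import jax.numpy as jnp

    def step(carry, x):
        carry = carry + x          # update the loop state
        return carry, carry        # (new carry, per-step output)

    total, partial_sums = jax.lax.scan(step, jnp.zeros(()), jnp.arange(5.0))
    # total == 10.0, partial_sums == [0., 1., 3., 6., 10.]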
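
Workshop 8 covers branching inside accelerated code. A minimal sketch contrasting jax.lax.cond (both branches are traced, one is executed) with jax.lax.select (both alternatives are computed, then chosen elementwise); the predicate and values are arbitrary.

    import jax
    import jax.numpy as jnp

    x = jnp.array([-1.0, 0.5, 2.0])

    # cond traces both branches but executes only one at runtime.
    y = jax.lax.cond(x.sum() > 0,
                     lambda x: jnp.maximum(x, 0.0),   # if the predicate is True
                     lambda x: jnp.zeros_like(x),     # otherwise
                     x)

    # select computes both alternatives, then chooses elementwise.
    z = jax.lax.select(x > 0, x, jnp.zeros_like(x))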
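
Workshop 9 uses a reverse scan when computing returns for PPO. A minimal sketch of computing discounted returns from a reward sequence with jax.lax.scan(..., reverse=True); the rewards and the discount factor are arbitrary numbers.

    import jax
    import jax.numpy as jnp

    def discounted_returns(rewards, discount=0.99):
        def step(future_return, reward):
            ret = reward + discount * future_return
            return ret, ret                     # carry backwards, emit per step
        _, returns = jax.lax.scan(step, jnp.zeros(()), rewards, reverse=True)
        return returns                          # aligned with the input rewards

    returns = discounted_returns(jnp.array([1.0, 0.0, 0.0, 1.0]))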

Bounty board

The first person to complete each of the following tasks will receive a prize of one hexadecimal Australian dollar (2.56 AUD) and name recognition on this webpage.

  1. Hi, profiling tools! Learn how to use JAX’s profiling tools, then find and remove a 2x performance bottleneck in the hijax GitHub repository.

  2. Hi, automatic parallelisation! Find and demonstrate a hardware/hyperparameter configuration where using jax.pmap yields a speedup of at least 10x. (See the jax.pmap sketch after the bounty list.)

  3. Hi, developmental interpretability! Accelerate and vectorise LLC estimation, and replicate your choice of figure from [1] or [2] (authors from these works are ineligible).

    Bounty 3 claimed by Rohan Hitchcock on 2024.08.07. [Rohan’s fork]

  4. Hi, mechanistic interpretability! Take an existing image model we have trained and then switch to optimising over the space of inputs (rather than parameters) to produce a feature visualisation along the lines of [3].

  5. Hi again, mechanistic interpretability! Take an existing image model we have trained and then train a sparse autoencoder (SAE) on it to produce a feature visualisation along the lines of [4].
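
For bounty 2: jax.pmap replicates a computation across local devices, with the leading axis of the input mapped over devices. A minimal sketch (the computation is a placeholder; on a single-device machine the leading axis is just 1, so no speedup is expected):

    import jax
    import jax.numpy as jnp

    n_devices = jax.local_device_count()
    xs = jnp.arange(n_devices * 4.0).reshape(n_devices, 4)   # leading axis = devices

    @jax.pmap
    def sum_of_squares(x):
        return (x ** 2).sum()

    print(sum_of_squares(xs))   # one result per device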

Other JAX resources

If you want to learn more about JAX, here are some good resources to know about.

  1. The official JAX documentation offers beginner and advanced tutorials, advice on frequent issues, a detailed API reference, and more.

  2. The University of Amsterdam’s Deep Learning Course notebooks include tutorials covering basic JAX as well as various deep learning topics implemented in JAX.

  3. Awesome JAX is a GitHub repository maintaining a list of JAX libraries, learning resources, papers, blog posts, and more.