Este tutorial vive en: https://phuijse.github.io/tutorial_jax/
Abstract: JAX es un framework que permite auto-diferenciar y acelerar código escrito en Python, con enfoque en rutinas de álgebra lineal. Adicionalmente permite la compilación transparente a GPU y TPU, que sumado a lo anterior lo posiciona como un nuevo estándar para la computación científica y el entrenamiento de modelos. En este tutorial aprenderemos sobre el ecosistema de librerías basadas en JAX para hacer entrenamiento eficiente de redes neuronales profundas. Revisaremos las generalidades de JAX así como también la librería flax para definir modelos neuronales, optax para optimizar funciones de costo y numpyro para hacer programación probabilística. Pondremos todo lo anterior en práctica en base a un ejemplo con datos de series de tiempo astronómicas.
Autor: Pablo Huijse, pablo dot huijse at uach dot cl
Este tutorial fue presentado en la Escuela de Verano IEEE en Inteligencia Computacional (EVIC) 2022, organizada por la UTEM y el capítulo chileno de IEEE CIS: https://cis.ieeechile.cl/
Si quieres aprender más sobre computación científica, machine learning, redes neuronales e inferencia bayesiana puedes visitar: