Skip to content

Tutorial de JAX para entrenar redes neuronales artificiales

License

Notifications You must be signed in to change notification settings

phuijse/tutorial_jax

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

27 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

Entrenando redes neuronales profundas utilizando el framework JAX

Este tutorial vive en: https://phuijse.github.io/tutorial_jax/

Abstract: JAX es un framework que permite auto-diferenciar y acelerar código escrito en Python, con enfoque en rutinas de álgebra lineal. Adicionalmente permite la compilación transparente a GPU y TPU, que sumado a lo anterior lo posiciona como un nuevo estándar para la computación científica y el entrenamiento de modelos. En este tutorial aprenderemos sobre el ecosistema de librerías basadas en JAX para hacer entrenamiento eficiente de redes neuronales profundas. Revisaremos las generalidades de JAX así como también la librería flax para definir modelos neuronales, optax para optimizar funciones de costo y numpyro para hacer programación probabilística. Pondremos todo lo anterior en práctica en base a un ejemplo con datos de series de tiempo astronómicas.

Autor: Pablo Huijse, pablo dot huijse at uach dot cl

Este tutorial fue presentado en la Escuela de Verano IEEE en Inteligencia Computacional (EVIC) 2022, organizada por la UTEM y el capítulo chileno de IEEE CIS: https://cis.ieeechile.cl/

Si quieres aprender más sobre computación científica, machine learning, redes neuronales e inferencia bayesiana puedes visitar: