LuxDL/Lux.jl

Elegant & Performant Scientific Machine Learning in Julia

A Pure Julia Deep Learning Framework designed for Scientific Machine Learning

💻 Installation

import Pkg
Pkg.add("Lux")

🤸 Quickstart

using Lux, Random, Optimisers, Zygote
# using LuxCUDA, AMDGPU, Metal, oneAPI # Optional packages for GPU support

# Seeding
rng = Random.default_rng()
Random.seed!(rng, 0)

# Construct the layer
model = Chain(BatchNorm(128), Dense(128, 256, tanh), BatchNorm(256),
              Chain(Dense(256, 1, tanh), Dense(1, 10)))

# Get the device determined by Lux
device = gpu_device()

# Parameter and State Variables
ps, st = Lux.setup(rng, model) .|> device

# Dummy Input
x = rand(rng, Float32, 128, 2) |> device

# Run the model
y, st = Lux.apply(model, x, ps, st)

# Gradients
gs = only(gradient(p -> sum(first(Lux.apply(model, x, p, st))), ps))

# Optimization
st_opt = Optimisers.setup(Optimisers.Adam(0.0001), ps)
st_opt, ps = Optimisers.update(st_opt, ps, gs)
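
The single gradient and update step above can be repeated to form a training loop. The following is a minimal sketch rather than a recommended recipe: the small model, the synthetic data, the mean-squared-error loss, the learning rate, and the helper names `mse_loss` and `train` are all assumptions made for illustration. It runs on the CPU and deliberately avoids stateful layers such as BatchNorm, whose updated state would otherwise have to be carried from one step to the next.

```julia
using Lux, Random, Optimisers, Zygote

rng = Random.default_rng()
Random.seed!(rng, 0)

# Illustrative model without stateful layers, so `st` can be reused unchanged
model = Chain(Dense(4, 8, tanh), Dense(8, 1))
ps, st = Lux.setup(rng, model)

# Synthetic regression data (assumption): the target is the squared norm of each column
x = rand(rng, Float32, 4, 16)
y = sum(abs2, x; dims=1)

# Hypothetical helper: mean squared error of the model output against `y`
function mse_loss(model, ps, st, x, y)
    y_pred, _ = Lux.apply(model, x, ps, st)
    return sum(abs2, y_pred .- y) / length(y)
end

# Hypothetical helper: repeat the gradient + update step from the Quickstart
function train(model, ps, st, x, y; nsteps=200, lr=0.01f0)
    opt_state = Optimisers.setup(Optimisers.Adam(lr), ps)
    for _ in 1:nsteps
        gs = only(Zygote.gradient(p -> mse_loss(model, p, st, x, y), ps))
        opt_state, ps = Optimisers.update(opt_state, ps, gs)
    end
    return ps
end

initial_loss = mse_loss(model, ps, st, x, y)
ps_trained = train(model, ps, st, x, y)
final_loss = mse_loss(model, ps_trained, st, x, y)
println("loss before: ", initial_loss, ", loss after: ", final_loss)
```

Passing the model, parameters, and state as explicit arguments mirrors the explicit-parameter style of the Quickstart and keeps the loop free of global variables.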

📚 Examples

Look in the examples directory for self-contained usage examples. The documentation has examples sorted into proper categories.

🧪 Testing

Running the full Lux.jl test suite takes a long time. Here is how to test only a portion of the code.

Each @testitem carries one or more tags. For example, consider the tests for SkipConnection:

@testitem "SkipConnection" setup=[SharedTestSetup] tags=[:core_layers] begin
    ...
end

We can test the group to which SkipConnection belongs by testing core_layers. To do so, set the LUX_TEST_GROUP environment variable (or rename the tag to narrow the test scope further):

export LUX_TEST_GROUP="core_layers"

Alternatively, change the default test group directly in runtests.jl:

# const LUX_TEST_GROUP = lowercase(get(ENV, "LUX_TEST_GROUP", "all"))
const LUX_TEST_GROUP = lowercase(get(ENV, "LUX_TEST_GROUP", "core_layers"))

Be sure to restore the default value "all" before submitting your changes.
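
If you would rather not export the variable in your shell or edit runtests.jl, Julia's built-in `withenv` can set it for the duration of a single call. The sketch below is only an illustration: `current_test_group` is a hypothetical helper that mirrors the lookup quoted above, and it assumes the test process inherits the environment of the session that launches it.

```julia
# Hypothetical helper mirroring the lookup quoted from runtests.jl:
# read LUX_TEST_GROUP, fall back to "all", and normalize to lowercase.
current_test_group(env=ENV) = lowercase(get(env, "LUX_TEST_GROUP", "all"))

# `withenv` sets the variable only inside the block and restores the
# previous value afterwards, so nothing has to be reverted by hand.
group_inside = withenv("LUX_TEST_GROUP" => "Core_Layers") do
    # A real run would call `Pkg.test()` here instead of just reading the value.
    current_test_group()
end

println(group_inside)
```

Because the value is lowercased on lookup, the capitalization used when setting the variable does not matter.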

Furthermore, to run a specific test by the name of its test item, you can use TestEnv.jl. Activate the Lux environment, then run the following:

using TestEnv; TestEnv.activate(); using ReTestItems;

# Assuming you are in the main directory of Lux
ReTestItems.runtests("test/"; name = "NAME OF THE TEST")

For the SkipConnection tests that would be:

ReTestItems.runtests("test/"; name = "SkipConnection")

🆘 Getting Help

For usage-related questions, please use GitHub Discussions or JuliaLang Discourse (machine learning domain), which allow questions and answers to be indexed. To report bugs, use GitHub issues or, even better, send in a pull request.

🧑‍🔬 Citation

If you found this library to be useful in academic work, then please cite:

@software{pal2023lux,
  author    = {Pal, Avik},
  title     = {{Lux: Explicit Parameterization of Deep Neural Networks in Julia}},
  month     = apr,
  year      = 2023,
  note      = {If you use this software, please cite it as below.},
  publisher = {Zenodo},
  version   = {v0.5.0},
  doi       = {10.5281/zenodo.7808904},
  url       = {https://doi.org/10.5281/zenodo.7808904}
}

@thesis{pal2023efficient,
  title     = {{On Efficient Training \& Inference of Neural Differential Equations}},
  author    = {Pal, Avik},
  year      = {2023},
  school    = {Massachusetts Institute of Technology}
}

Also consider starring our GitHub repo.