
Neural ODE tutorial with SimpleChains.jl
Abhishek-1Bhatt committed Sep 24, 2022
1 parent 7ddbd47 commit 8d858cf
Showing 3 changed files with 85 additions and 0 deletions.
2 changes: 2 additions & 0 deletions docs/Project.toml
@@ -30,6 +30,8 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
1 change: 1 addition & 0 deletions docs/pages.jl
@@ -14,6 +14,7 @@ pages = [
"training_tips/divergence.md",
"training_tips/multiple_nn.md"],
"Neural Ordinary Differential Equation (Neural ODE) Tutorials" => Any["neural_ode/neural_ode_flux.md",
"neural_ode/simplechains.md",
"neural_ode/neural_gde.md",
"neural_ode/minibatch.md"],
"Stochastic Differential Equation (SDE) Tutorials" => Any["sde_fitting/optimization_sde.md"],
82 changes: 82 additions & 0 deletions docs/src/neural_ode/simplechains.md
@@ -0,0 +1,82 @@
# Neural Ordinary Differential Equations with SimpleChains

[SimpleChains](https://github.com/PumasAI/SimpleChains.jl) has demonstrated performance boosts of ~5x and ~30x for training and evaluation, respectively, compared to mainstream deep learning frameworks like PyTorch in the specific case of small neural networks. For the nitty-gritty details, as well as some SciML-related videos on the need for and applications of such a library, refer to this [blog post](https://julialang.org/blog/2022/04/simple-chains/). So, for Scientific Machine Learning, how do we go about training a neural ODE with such a library? This tutorial walks through the process.

## Training Data

First, we'll need data for training the Neural ODE. It can be obtained by solving the ODE `u' = f(u,p,t)` numerically using the SciML ecosystem in Julia.

```@example sc_neuralode
using SimpleChains, StaticArrays, OrdinaryDiffEq, SciMLSensitivity, Optimization, OptimizationFlux, Plots

u0 = @SArray Float32[2.0, 0.0]
datasize = 30
tspan = (0.0f0, 1.5f0)
tsteps = range(tspan[1], tspan[2], length = datasize)

# The true dynamics used to generate the training data
function trueODE(u, p, t)
    true_A = @SMatrix Float32[-0.1 2.0; -2.0 -0.1]
    ((u .^ 3)' * true_A)'
end

prob = ODEProblem(trueODE, u0, tspan)
data = Array(solve(prob, Tsit5(), saveat = tsteps))
```
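
As a quick sanity check, not part of the original example, we can visualize the generated trajectories; `Plots` is already loaded above, and `data` is a 2×30 matrix of states saved at `tsteps`.

```julia
# Plot both state trajectories that will serve as training data
plot(tsteps, data[1, :], label = "u1 data")
plot!(tsteps, data[2, :], label = "u2 data")
```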

## Neural Network

Next, we set up a small neural network. It will be trained to output the derivative of the solution at each time step, given the value of the solution at the previous time step and the parameters of the network. Thus, we are treating the neural network as a function `f(u,p,t)`. The difference is that, instead of relying on knowing the exact equation of the ODE, we learn the dynamics from the data alone.

```@example sc_neuralode
# A small network: cube the input, one hidden tanh layer, linear output of size 2
sc = SimpleChain(
    static(2),
    Activation(x -> x .^ 3),
    TurboDense{true}(tanh, static(50)),
    TurboDense{true}(identity, static(2))
)

p_nn = SimpleChains.init_params(sc)

# Treat the network as the right-hand side f(u, p, t) of the ODE
f(u, p, t) = sc(u, p)
```
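
Before coupling the network to the solver, it can help to confirm that it maps a state vector to a derivative estimate of the same size. This is a minimal illustrative check using only the objects defined above, not part of the original tutorial.

```julia
# The untrained network should return a 2-element derivative estimate for u0
du0 = f(u0, p_nn, tspan[1])
@assert length(du0) == 2
```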

## NeuralODE, Prediction and Loss

Now, instead of the function `trueODE(u,p,t)` from the first code block, we pass the neural network to the ODE solver. This is our Neural ODE. To train it, we obtain predictions from the model and compute the L2 loss against the data generated numerically above.

```@example sc_neuralode
prob_nn = ODEProblem(f, u0, tspan)

# Solve the Neural ODE at the current parameters p
function predict_neuralode(p)
    Array(solve(prob_nn, Tsit5(); p = p, saveat = tsteps,
                sensealg = QuadratureAdjoint(autojacvec = ZygoteVJP())))
end

# L2 loss of the prediction against the training data
function loss_neuralode(p)
    pred = predict_neuralode(p)
    loss = sum(abs2, data .- pred)
    return loss, pred
end
```
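
As an illustrative check, not part of the original example, the loss can be evaluated once at the untrained parameters to make sure the whole pipeline runs before optimizing.

```julia
# Evaluate loss and prediction at the initial (untrained) parameters
l0, pred0 = loss_neuralode(p_nn)
println("Initial loss: ", l0)
```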

## Training

The next step is to minimize the loss so that the Neural ODE gets trained. But to do that, we have to be able to backpropagate through the Neural ODE model. Backpropagation through the neural network itself is the easy part, and we get it out of the box with any deep learning package (although not as fast as with SimpleChains for the small network used here). The harder part is propagating the sensitivities of the loss back through the ODE solver and then into the neural network.

The adjoint of a neural ODE can be computed through the various AD algorithms available in SciMLSensitivity.jl. However, working with [StaticArrays](https://github.com/JuliaArrays/StaticArrays.jl) in SimpleChains.jl requires a special adjoint method, since StaticArrays do not allow mutation. The adjoint methods generally make heavy use of in-place mutation to stay performant with ordinary heap-allocated arrays, so for our statically sized, stack-allocated StaticArrays the ODE adjoint has to be computed entirely out of place. Hence, we specifically use the `QuadratureAdjoint(autojacvec=ZygoteVJP())` adjoint algorithm in the `solve` call inside `predict_neuralode(p)`, which computes everything out-of-place when `u0` is a StaticArray. With that in place, we can move on to training the Neural ODE.
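
To make the sensitivity machinery concrete, here is a minimal sketch of differentiating the loss with respect to the network parameters directly. It assumes Zygote is available in the environment (the `Optimization.AutoZygote()` call below relies on it); the reverse pass goes through the ODE solve via the quadrature adjoint chosen above.

```julia
using Zygote  # assumed available; AutoZygote() below uses it under the hood

# Reverse-mode gradient of the scalar loss w.r.t. the network parameters,
# propagated back through the ODE solve by QuadratureAdjoint + ZygoteVJP
grad_p = Zygote.gradient(p -> loss_neuralode(p)[1], p_nn)[1]
```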

```@example sc_neuralode
# Callback to monitor training: print the loss and plot the current fit
callback = function (p, l, pred; doplot = true)
    display(l)
    plt = scatter(tsteps, data[1, :], label = "data")
    scatter!(plt, tsteps, pred[1, :], label = "prediction")
    if doplot
        display(plot(plt))
    end
    return false
end

optf = Optimization.OptimizationFunction((x, p) -> loss_neuralode(x), Optimization.AutoZygote())
optprob = Optimization.OptimizationProblem(optf, p_nn)

res = Optimization.solve(optprob, ADAM(0.05), callback = callback, maxiters = 300)
```
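
After training, the optimized parameters are returned in `res.u` (the standard Optimization.jl result field), so we can, for instance, compare the final fit against the data. A brief usage sketch, not part of the original example:

```julia
# Predict with the trained parameters and plot against the training data
p_trained = res.u
final_pred = predict_neuralode(p_trained)
plt = scatter(tsteps, data[1, :], label = "data")
scatter!(plt, tsteps, final_pred[1, :], label = "trained prediction")
```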
