Isolate broken tests #786

Merged (6 commits, Feb 25, 2023)
18 changes: 10 additions & 8 deletions docs/src/examples/dae/physical_constraints.md
@@ -9,7 +9,8 @@ zeros, then we have a constraint defined by the right-hand side. Using
terms must add to one. An example of this is as follows:

```@example dae
-using Lux, DiffEqFlux, Optimization, OptimizationNLopt, DifferentialEquations, Plots
+using Lux, ComponentArrays, DiffEqFlux, Optimization, OptimizationNLopt,
+    DifferentialEquations, Plots

using Random
rng = Random.default_rng()
@@ -42,7 +43,7 @@ pinit, st = Lux.setup(rng, nn_dudt2)

model_stiff_ndae = NeuralODEMM(nn_dudt2, (u, p, t) -> [u[1] + u[2] + u[3] - 1],
tspan, M, Rodas5(autodiff = false), saveat = 0.1)
-model_stiff_ndae(u₀, Lux.ComponentArray(pinit), st)
+model_stiff_ndae(u₀, ComponentArray(pinit), st)

function predict_stiff_ndae(p)
return model_stiff_ndae(u₀, p, st)[1]
@@ -59,11 +60,11 @@ end
# return false
# end

-l1 = first(loss_stiff_ndae(Lux.ComponentArray(pinit)))
+l1 = first(loss_stiff_ndae(ComponentArray(pinit)))

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss_stiff_ndae(x), adtype)
-optprob = Optimization.OptimizationProblem(optf, Lux.ComponentArray(pinit))
+optprob = Optimization.OptimizationProblem(optf, ComponentArray(pinit))
result_stiff = Optimization.solve(optprob, NLopt.LD_LBFGS(), maxiters = 100)
```

@@ -72,7 +73,8 @@ result_stiff = Optimization.solve(optprob, NLopt.LD_LBFGS(), maxiters = 100)
### Load Packages

```@example dae2
-using Lux, DiffEqFlux, Optimization, OptimizationNLopt, DifferentialEquations, Plots
+using Lux, ComponentArrays, DiffEqFlux, Optimization, OptimizationNLopt,
+    DifferentialEquations, Plots

using Random
rng = Random.default_rng()
@@ -142,7 +144,7 @@ pinit, st = Lux.setup(rng, nn_dudt2)

model_stiff_ndae = NeuralODEMM(nn_dudt2, (u, p, t) -> [u[1] + u[2] + u[3] - 1],
tspan, M, Rodas5(autodiff = false), saveat = 0.1)
-model_stiff_ndae(u₀, Lux.ComponentArray(pinit), st)
+model_stiff_ndae(u₀, ComponentArray(pinit), st)
```

Because this is a stiff problem, we have manually imposed that sum constraint via
@@ -176,7 +178,7 @@ function loss_stiff_ndae(p)
return loss, pred
end

-l1 = first(loss_stiff_ndae(Lux.ComponentArray(pinit)))
+l1 = first(loss_stiff_ndae(ComponentArray(pinit)))
```

Notice that we are feeding the **parameters** of `model_stiff_ndae` to the `loss_stiff_ndae`
@@ -206,6 +208,6 @@ Finally, training with `Optimization.solve` by passing: *loss function*, *model
```@example dae2
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss_stiff_ndae(x), adtype)
-optprob = Optimization.OptimizationProblem(optf, Lux.ComponentArray(pinit))
+optprob = Optimization.OptimizationProblem(optf, ComponentArray(pinit))
result_stiff = Optimization.solve(optprob, NLopt.LD_LBFGS(), maxiters = 100)
```
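
Note: the doc changes in this file all follow one pattern — `ComponentArray` now comes from ComponentArrays.jl itself rather than through `Lux.ComponentArray`, so ComponentArrays is added to the `using` lines. A minimal sketch of the pattern (the two-layer model here is hypothetical, not from the docs):

```julia
# Sketch: flatten Lux parameters with ComponentArrays.jl directly.
using Lux, ComponentArrays, Random

rng = Random.default_rng()
model = Lux.Chain(Lux.Dense(2, 16, tanh), Lux.Dense(16, 2))
ps, st = Lux.setup(rng, model)

# ComponentArray turns the nested NamedTuple of parameters into one flat
# vector that optimizers can treat as a single array, with named access.
p = ComponentArray(ps)
p.layer_1.weight  # named view into the flat parameter vector
```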
6 changes: 3 additions & 3 deletions docs/src/examples/ode/exogenous_input.md
@@ -40,8 +40,8 @@ In the following example, a discrete exogenous input signal `ex` is defined and
used as an input into the neural network of a neural ODE system.

```@example exogenous
-using DifferentialEquations, Lux, DiffEqFlux, Optimization, OptimizationPolyalgorithms,
-    OptimizationFlux, Plots, Random
+using DifferentialEquations, Lux, ComponentArrays, DiffEqFlux, Optimization,
+    OptimizationPolyalgorithms, OptimizationFlux, Plots, Random

rng = Random.default_rng()
tspan = (0.1f0, Float32(10.0))
@@ -88,7 +88,7 @@ end

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
-optprob = Optimization.OptimizationProblem(optf, Lux.ComponentArray(p_model))
+optprob = Optimization.OptimizationProblem(optf, ComponentArray(p_model))

res0 = Optimization.solve(optprob, PolyOpt(), maxiters = 100)

5 changes: 3 additions & 2 deletions docs/src/tutorials/training_tips/local_minima.md
@@ -16,7 +16,8 @@ before, except with one small twist: we wish to find the neural ODE that fits
on `(0,5.0)`. Naively, we use the same training strategy as before:

```@example iterativefit
-using DifferentialEquations, SciMLSensitivity, Optimization, OptimizationFlux
+using DifferentialEquations, ComponentArrays, SciMLSensitivity, Optimization,
+    OptimizationFlux
using Lux, Plots, Random

rng = Random.default_rng()
@@ -38,7 +39,7 @@ dudt2 = Lux.Chain(ActivationFunction(x -> x .^ 3),
Lux.Dense(16, 2))

pinit, st = Lux.setup(rng, dudt2)
-pinit = Lux.ComponentArray(pinit)
+pinit = ComponentArray(pinit)

function neuralode_f(u, p, t)
dudt2(u, p, st)[1]
15 changes: 8 additions & 7 deletions docs/src/tutorials/training_tips/multiple_nn.md
@@ -7,7 +7,8 @@ this kind of study.
The following is a fully working demo on the Fitzhugh-Nagumo ODE:

```@example
-using Lux, DiffEqFlux, Optimization, OptimizationNLopt, DifferentialEquations, Random
+using Lux, DiffEqFlux, ComponentArrays, Optimization, OptimizationNLopt,
+    DifferentialEquations, Random

rng = Random.default_rng()
Random.seed!(rng, 1)
@@ -38,13 +39,13 @@ NN_2 = Lux.Chain(Lux.Dense(3, 16, tanh), Lux.Dense(16, 1))
p2, st2 = Lux.setup(rng, NN_2)
scaling_factor = 1.0f0

-p1 = Lux.ComponentArray(p1)
-p2 = Lux.ComponentArray(p2)
+p1 = ComponentArray(p1)
+p2 = ComponentArray(p2)

-p = Lux.ComponentArray{eltype(p1)}()
-p = Lux.ComponentArray(p; p1)
-p = Lux.ComponentArray(p; p2)
-p = Lux.ComponentArray(p; scaling_factor)
+p = ComponentArray{eltype(p1)}()
+p = ComponentArray(p; p1)
+p = ComponentArray(p; p2)
+p = ComponentArray(p; scaling_factor)

function dudt_(u, p, t)
v, w = u
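
Note: the block above builds one flat parameter vector out of two networks plus a scalar; each keyword splat appends a named component. A small hypothetical sketch of the pattern, with toy blocks standing in for the real network parameters:

```julia
# Sketch: compose several parameter groups into one ComponentArray.
using ComponentArrays

p1 = ComponentArray(a = 1.0f0, b = 2.0f0)
p2 = ComponentArray(c = 3.0f0)
scaling_factor = 1.0f0

p = ComponentArray{Float32}()          # start from an empty Float32 array
p = ComponentArray(p; p1)              # appends a `p1` block
p = ComponentArray(p; p2)              # appends a `p2` block
p = ComponentArray(p; scaling_factor)  # scalars append as single components

p.p1.a, p.p2.c, p.scaling_factor       # named access inside e.g. dudt_
```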
3 changes: 2 additions & 1 deletion src/SciMLSensitivity.jl
@@ -34,7 +34,8 @@ import SciMLBase: unwrapped_f

import SciMLBase: AbstractOverloadingSensitivityAlgorithm, AbstractSensitivityAlgorithm,
AbstractForwardSensitivityAlgorithm, AbstractAdjointSensitivityAlgorithm,
-    AbstractSecondOrderSensitivityAlgorithm, AbstractShadowingSensitivityAlgorithm
+    AbstractSecondOrderSensitivityAlgorithm,
+    AbstractShadowingSensitivityAlgorithm

include("hasbranching.jl")
include("sensitivity_algorithms.jl")
6 changes: 3 additions & 3 deletions src/backsolve_adjoint.jl
@@ -147,7 +147,7 @@ end
# check if solution was terminated, then use reduced time span
terminated = false
if hasfield(typeof(sol), :retcode)
-if sol.retcode == :Terminated
+if sol.retcode == ReturnCode.Terminated
tspan = (tspan[1], sol.t[end])
terminated = true
end
@@ -253,7 +253,7 @@ end
# check if solution was terminated, then use reduced time span
terminated = false
if hasfield(typeof(sol), :retcode)
-if sol.retcode == :Terminated
+if sol.retcode == ReturnCode.Terminated
tspan = (tspan[1], sol.t[end])
terminated = true
end
@@ -370,7 +370,7 @@ end
# check if solution was terminated, then use reduced time span
terminated = false
if hasfield(typeof(sol), :retcode)
-if sol.retcode == :Terminated
+if sol.retcode == ReturnCode.Terminated
tspan = (tspan[1], sol.t[end])
terminated = true
end
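
Note: these hunks track SciMLBase's move from `Symbol` return codes to the `ReturnCode` enum, so `sol.retcode == :Terminated` becomes `sol.retcode == ReturnCode.Terminated`. A small hypothetical example of a terminated solve and the new comparison:

```julia
# Sketch: terminate an integration via callback, then inspect retcode.
using OrdinaryDiffEq, SciMLBase

f(u, p, t) = -u
condition(u, t, integrator) = u[1] - 0.5   # crosses zero when u[1] hits 0.5
affect!(integrator) = terminate!(integrator)
cb = ContinuousCallback(condition, affect!)

sol = solve(ODEProblem(f, [1.0], (0.0, 10.0)), Tsit5(), callback = cb)

sol.retcode == ReturnCode.Terminated   # the enum comparison used above
SciMLBase.successful_retcode(sol)      # Terminated counts as successful
```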
6 changes: 3 additions & 3 deletions src/interpolating_adjoint.jl
@@ -290,7 +290,7 @@ end
# check if solution was terminated, then use reduced time span
terminated = false
if hasfield(typeof(sol), :retcode)
-if sol.retcode == :Terminated
+if sol.retcode == ReturnCode.Terminated
tspan = (tspan[1], sol.t[end])
terminated = true
end
@@ -402,7 +402,7 @@ end
# check if solution was terminated, then use reduced time span
terminated = false
if hasfield(typeof(sol), :retcode)
-if sol.retcode == :Terminated
+if sol.retcode == ReturnCode.Terminated
tspan = (tspan[1], sol.t[end])
terminated = true
end
@@ -542,7 +542,7 @@ end
# check if solution was terminated, then use reduced time span
terminated = false
if hasfield(typeof(sol), :retcode)
-if sol.retcode == :Terminated
+if sol.retcode == ReturnCode.Terminated
tspan = (tspan[1], sol.t[end])
terminated = true
end
12 changes: 6 additions & 6 deletions src/quadrature_adjoint.jl
@@ -107,7 +107,7 @@ end

terminated = false
if hasfield(typeof(sol), :retcode)
-if sol.retcode == :Terminated
+if sol.retcode == ReturnCode.Terminated
tspan = (tspan[1], sol.t[end])
terminated = true
end
@@ -141,11 +141,11 @@ end
original_mm = sol.prob.f.mass_matrix
if original_mm === I || original_mm === (I, I)
odefun = ODEFunction{ArrayInterface.ismutable(z0), true}(sense,
-    jac_prototype = adjoint_jac_prototype)
+    jac_prototype = adjoint_jac_prototype)
else
odefun = ODEFunction{ArrayInterface.ismutable(z0), true}(sense,
-    mass_matrix = sol.prob.f.mass_matrix',
-    jac_prototype = adjoint_jac_prototype)
+    mass_matrix = sol.prob.f.mass_matrix',
+    jac_prototype = adjoint_jac_prototype)
end
if RetCB
return ODEProblem(odefun, z0, tspan, p, callback = cb), rcb
@@ -359,12 +359,12 @@ function _adjoint_sensitivities(sol, sensealg::QuadratureAdjoint, alg; t = nothing
end

# correction for end interval.
-if t[end] != sol.prob.tspan[2] && sol.retcode !== :Terminated
+if t[end] != sol.prob.tspan[2] && sol.retcode !== ReturnCode.Terminated
res .+= quadgk(integrand, t[end], sol.prob.tspan[end],
atol = abstol, rtol = reltol)[1]
end

-if sol.retcode === :Terminated
+if sol.retcode === ReturnCode.Terminated
integrand = update_integrand_and_dgrad(res, sensealg, callback, integrand,
adj_prob, sol, dgdu_discrete,
dgdp_discrete, dλ, dgrad, t[end],
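
Note: in the end-interval correction above, `quadgk` returns an `(integral, error_estimate)` tuple, which is why the result is indexed with `[1]`. A tiny standalone sketch of that call shape:

```julia
# Sketch: QuadGK's quadgk returns the integral and an error estimate.
using QuadGK

integrand(t) = exp(-t)
val = quadgk(integrand, 0.0, 1.0; atol = 1e-8, rtol = 1e-6)[1]
# val ≈ 1 - exp(-1)
```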
2 changes: 1 addition & 1 deletion test/complex_no_u.jl
@@ -1,5 +1,5 @@
using OrdinaryDiffEq, SciMLSensitivity, LinearAlgebra, Optimization, OptimizationFlux, Flux
-nn = Chain(Dense(1, 16), Dense(16, 16, tanh), Dense(16, 2))
+nn = Chain(Dense(1, 16), Dense(16, 16, tanh), Dense(16, 2)) |> f64
initial, re = Flux.destructure(nn)

function ode2!(u, p, t)
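
Note: Flux layers initialize weights as Float32, so piping the model through `f64` makes `destructure` hand back Float64 parameters, matching the Float64 state this test solves with. A hedged sketch of the effect:

```julia
# Sketch: `f64` converts a whole Flux model's parameters to Float64.
using Flux

nn32 = Chain(Dense(1, 16), Dense(16, 2))
eltype(Flux.destructure(nn32)[1])   # Float32 by default

nn64 = nn32 |> f64
eltype(Flux.destructure(nn64)[1])   # Float64 after conversion
```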
4 changes: 2 additions & 2 deletions test/partial_neural.jl
@@ -19,9 +19,9 @@ prob = ODEProblem(dudt2_, x, tspan, _p)
solve(prob, Tsit5())

function predict_rd(θ)
-    Array(solve(prob, Tsit5(), u0 = θ[1:2], p = θ[3:end], abstol = 1e-7, reltol = 1e-5,
-        sensealg = TrackerAdjoint()))
+    Array(solve(prob, Tsit5(), u0 = θ[1:2], p = θ[3:end], abstol = 1.0f-7, reltol = 1.0f-5))
end

loss_rd(p) = sum(abs2, x - 1 for x in predict_rd(p))
l = loss_rd(θ)

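
Note: two things change in `predict_rd` — the explicit `sensealg = TrackerAdjoint()` is dropped (letting the default sensitivity algorithm be chosen), and the tolerance literals become Float32 (`1.0f-7`, `1.0f-5`); `1e-7` is a Float64 literal, so the Float32 form presumably keeps the tolerances in the same precision as the Float32 problem. A small sketch of the literal-type point, on a hypothetical problem:

```julia
# Sketch: Float32 tolerance literals for a Float32 ODE problem.
using OrdinaryDiffEq

f(u, p, t) = 1.01f0 .* u
prob32 = ODEProblem(f, [0.5f0], (0.0f0, 1.0f0))

# 1.0f-7 is a Float32 literal; 1e-7 would be Float64.
sol = solve(prob32, Tsit5(), abstol = 1.0f-7, reltol = 1.0f-5)
```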
18 changes: 2 additions & 16 deletions test/steady_state.jl
@@ -73,14 +73,7 @@ Random.seed!(12345)
@info "Calculate adjoint sensitivities from autodiff & numerical diff"
function G(p)
tmp_prob = remake(prob, u0 = convert.(eltype(p), prob.u0), p = p)
-sol = solve(tmp_prob,
-    SSRootfind(nlsolve = (f!, u0, abstol) -> (res = NLsolve.nlsolve(f!,
-        u0,
-        autodiff = :forward,
-        method = :newton,
-        iterations = Int(1e6),
-        ftol = 1e-14);
-    res.zero)))
+sol = solve(tmp_prob, DynamicSS(Rodas5()))
A = convert(Array, sol)
g(A, p, nothing)
end
@@ -259,14 +252,7 @@ Random.seed!(12345)
@testset "for u0: (should be zero, steady state does not depend on initial condition)" begin
res5 = ForwardDiff.gradient(prob.u0) do u0
tmp_prob = remake(prob, u0 = u0)
-sol = solve(tmp_prob,
-    SSRootfind(nlsolve = (f!, u0, abstol) -> (res = NLsolve.nlsolve(f!,
-        u0,
-        autodiff = :forward,
-        method = :newton,
-        iterations = Int(1e6),
-        ftol = 1e-14);
-    res.zero)))
+sol = solve(tmp_prob, DynamicSS(Rodas5()))
A = convert(Array, sol)
g(A, p, nothing)
end
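
Note: both test sites replace the hand-rolled `SSRootfind`/NLsolve closure with `DynamicSS(Rodas5())`, which finds the steady state by integrating the ODE toward equilibrium with a stiff solver instead of running Newton's method directly. A minimal sketch of that pattern on a hypothetical scalar ODE:

```julia
# Sketch: solve a SteadyStateProblem by forward integration (DynamicSS).
using SteadyStateDiffEq, OrdinaryDiffEq

f(u, p, t) = p .* u .- u .^ 3       # stable steady state at u = sqrt(p)
prob = SteadyStateProblem(f, [1.0], [2.0])

sol = solve(prob, DynamicSS(Rodas5()))
sol.u   # ≈ [sqrt(2)]
```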