From 87464b2643a2794e753a346b37fb0d012e49cf69 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Fri, 24 Feb 2023 07:28:06 -0500 Subject: [PATCH 1/6] Isolate broken tests --- test/partial_neural.jl | 4 ++-- test/steady_state.jl | 22 +++++----------------- 2 files changed, 7 insertions(+), 19 deletions(-) diff --git a/test/partial_neural.jl b/test/partial_neural.jl index 9de382bc9..9916eb5de 100644 --- a/test/partial_neural.jl +++ b/test/partial_neural.jl @@ -19,9 +19,9 @@ prob = ODEProblem(dudt2_, x, tspan, _p) solve(prob, Tsit5()) function predict_rd(θ) - Array(solve(prob, Tsit5(), u0 = θ[1:2], p = θ[3:end], abstol = 1e-7, reltol = 1e-5, - sensealg = TrackerAdjoint())) + Array(solve(prob, Tsit5(), u0 = θ[1:2], p = θ[3:end], abstol = 1f-7, reltol = 1f-5)) end + loss_rd(p) = sum(abs2, x - 1 for x in predict_rd(p)) l = loss_rd(θ) diff --git a/test/steady_state.jl b/test/steady_state.jl index 2a3012df4..a7f73faa6 100644 --- a/test/steady_state.jl +++ b/test/steady_state.jl @@ -73,14 +73,7 @@ Random.seed!(12345) @info "Calculate adjoint sensitivities from autodiff & numerical diff" function G(p) tmp_prob = remake(prob, u0 = convert.(eltype(p), prob.u0), p = p) - sol = solve(tmp_prob, - SSRootfind(nlsolve = (f!, u0, abstol) -> (res = NLsolve.nlsolve(f!, - u0, - autodiff = :forward, - method = :newton, - iterations = Int(1e6), - ftol = 1e-14); - res.zero))) + sol = solve(tmp_prob, DynamicSS(Rodas5())) A = convert(Array, sol) g(A, p, nothing) end @@ -259,14 +252,7 @@ Random.seed!(12345) @testset "for u0: (should be zero, steady state does not depend on initial condition)" begin res5 = ForwardDiff.gradient(prob.u0) do u0 tmp_prob = remake(prob, u0 = u0) - sol = solve(tmp_prob, - SSRootfind(nlsolve = (f!, u0, abstol) -> (res = NLsolve.nlsolve(f!, - u0, - autodiff = :forward, - method = :newton, - iterations = Int(1e6), - ftol = 1e-14); - res.zero))) + sol = solve(tmp_prob, DynamicSS(Rodas5())) A = convert(Array, sol) g(A, p, nothing) end @@ -437,6 +423,7 @@ end @test dp1≈dp8 rtol=1e-10 end +#= @testset "Continuous sensitivity tools" begin function f!(du, u, p, t) du[1] = p[1] + p[2] * u[1] @@ -485,7 +472,7 @@ end u0, p) @test du0≈Zdu0 atol=1e-4 @test dp≈Zdp atol=1e-4 - Zdu0, Zdp = Zygote.gradient((u0, p) -> loss(u0, p, sensealg = BacksolveAdjoint()), + Zdu0, Zdp = Zygote.gradient((u0, p) -> loss(u0, p, sensealg = BacksolveAdjoint(autojacvec = EnzymeVJP())), u0, p) @test du0≈Zdu0 atol=1e-4 @@ -577,3 +564,4 @@ end @test dp≈Zdp atol=1e-4 end end +=# \ No newline at end of file From 330bf66d96e57e568760178093301795d6acea8e Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Fri, 24 Feb 2023 07:56:14 -0500 Subject: [PATCH 2/6] Fix termination check --- src/backsolve_adjoint.jl | 6 +++--- src/interpolating_adjoint.jl | 6 +++--- src/quadrature_adjoint.jl | 6 +++--- test/steady_state.jl | 4 +--- 4 files changed, 10 insertions(+), 12 deletions(-) diff --git a/src/backsolve_adjoint.jl b/src/backsolve_adjoint.jl index e0c12564d..ccb0e2b12 100644 --- a/src/backsolve_adjoint.jl +++ b/src/backsolve_adjoint.jl @@ -147,7 +147,7 @@ end # check if solution was terminated, then use reduced time span terminated = false if hasfield(typeof(sol), :retcode) - if sol.retcode == :Terminated + if sol.retcode == ReturnCode.Terminated tspan = (tspan[1], sol.t[end]) terminated = true end @@ -253,7 +253,7 @@ end # check if solution was terminated, then use reduced time span terminated = false if hasfield(typeof(sol), :retcode) - if sol.retcode == :Terminated + if sol.retcode == ReturnCode.Terminated tspan = (tspan[1], sol.t[end]) terminated = true end @@ -370,7 +370,7 @@ end # check if solution was terminated, then use reduced time span terminated = false if hasfield(typeof(sol), :retcode) - if sol.retcode == :Terminated + if sol.retcode == ReturnCode.Terminated tspan = (tspan[1], sol.t[end]) terminated = true end diff --git a/src/interpolating_adjoint.jl b/src/interpolating_adjoint.jl index 162f2e0bf..56348b92a 100644 --- a/src/interpolating_adjoint.jl +++ b/src/interpolating_adjoint.jl @@ -290,7 +290,7 @@ end # check if solution was terminated, then use reduced time span terminated = false if hasfield(typeof(sol), :retcode) - if sol.retcode == :Terminated + if sol.retcode == ReturnCode.Terminated tspan = (tspan[1], sol.t[end]) terminated = true end @@ -402,7 +402,7 @@ end # check if solution was terminated, then use reduced time span terminated = false if hasfield(typeof(sol), :retcode) - if sol.retcode == :Terminated + if sol.retcode == ReturnCode.Terminated tspan = (tspan[1], sol.t[end]) terminated = true end @@ -542,7 +542,7 @@ end # check if solution was terminated, then use reduced time span terminated = false if hasfield(typeof(sol), :retcode) - if sol.retcode == :Terminated + if sol.retcode == ReturnCode.Terminated tspan = (tspan[1], sol.t[end]) terminated = true end diff --git a/src/quadrature_adjoint.jl b/src/quadrature_adjoint.jl index cd33b8230..c3252192e 100644 --- a/src/quadrature_adjoint.jl +++ b/src/quadrature_adjoint.jl @@ -107,7 +107,7 @@ end terminated = false if hasfield(typeof(sol), :retcode) - if sol.retcode == :Terminated + if sol.retcode == ReturnCode.Terminated tspan = (tspan[1], sol.t[end]) terminated = true end @@ -359,12 +359,12 @@ function _adjoint_sensitivities(sol, sensealg::QuadratureAdjoint, alg; t = nothi end # correction for end interval. - if t[end] != sol.prob.tspan[2] && sol.retcode !== :Terminated + if t[end] != sol.prob.tspan[2] && sol.retcode !== ReturnCode.Terminated res .+= quadgk(integrand, t[end], sol.prob.tspan[end], atol = abstol, rtol = reltol)[1] end - if sol.retcode === :Terminated + if sol.retcode === ReturnCode.Terminated integrand = update_integrand_and_dgrad(res, sensealg, callback, integrand, adj_prob, sol, dgdu_discrete, dgdp_discrete, dλ, dgrad, t[end], diff --git a/test/steady_state.jl b/test/steady_state.jl index a7f73faa6..c93eae0a1 100644 --- a/test/steady_state.jl +++ b/test/steady_state.jl @@ -423,7 +423,6 @@ end @test dp1≈dp8 rtol=1e-10 end -#= @testset "Continuous sensitivity tools" begin function f!(du, u, p, t) du[1] = p[1] + p[2] * u[1] @@ -563,5 +562,4 @@ end @test du0≈Zdu0 atol=1e-4 @test dp≈Zdp atol=1e-4 end -end -=# \ No newline at end of file +end \ No newline at end of file From 3d915220271cd09f871be1732c0929239c4cac22 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Fri, 24 Feb 2023 07:56:48 -0500 Subject: [PATCH 3/6] Update test/steady_state.jl --- test/steady_state.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/steady_state.jl b/test/steady_state.jl index c93eae0a1..326905a98 100644 --- a/test/steady_state.jl +++ b/test/steady_state.jl @@ -471,7 +471,7 @@ end u0, p) @test du0≈Zdu0 atol=1e-4 @test dp≈Zdp atol=1e-4 - Zdu0, Zdp = Zygote.gradient((u0, p) -> loss(u0, p, sensealg = BacksolveAdjoint(autojacvec = EnzymeVJP())), + Zdu0, Zdp = Zygote.gradient((u0, p) -> loss(u0, p, sensealg = BacksolveAdjoint()), u0, p) @test du0≈Zdu0 atol=1e-4 From f81d3bc301d9f19040f41af95b5a2945893bd886 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Fri, 24 Feb 2023 18:55:30 -0500 Subject: [PATCH 4/6] Fix Flux types --- test/complex_no_u.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/complex_no_u.jl b/test/complex_no_u.jl index 7c589b9b7..297586084 100644 --- a/test/complex_no_u.jl +++ b/test/complex_no_u.jl @@ -1,5 +1,5 @@ using OrdinaryDiffEq, SciMLSensitivity, LinearAlgebra, Optimization, OptimizationFlux, Flux -nn = Chain(Dense(1, 16), Dense(16, 16, tanh), Dense(16, 2)) +nn = Chain(Dense(1, 16), Dense(16, 16, tanh), Dense(16, 2)) |> f64 initial, re = Flux.destructure(nn) function ode2!(u, p, t) From 83dc0ba62c2ee9b35ca2823fb84b47757589d877 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Sat, 25 Feb 2023 03:22:40 -0500 Subject: [PATCH 5/6] Fix componentarray importing in docs --- docs/src/examples/dae/physical_constraints.md | 16 ++++++++-------- docs/src/examples/ode/exogenous_input.md | 6 +++--- docs/src/tutorials/training_tips/local_minima.md | 4 ++-- docs/src/tutorials/training_tips/multiple_nn.md | 14 +++++++------- 4 files changed, 20 insertions(+), 20 deletions(-) diff --git a/docs/src/examples/dae/physical_constraints.md b/docs/src/examples/dae/physical_constraints.md index ecf20fe65..bf2cf5d77 100644 --- a/docs/src/examples/dae/physical_constraints.md +++ b/docs/src/examples/dae/physical_constraints.md @@ -9,7 +9,7 @@ zeros, then we have a constraint defined by the right-hand side. Using terms must add to one. An example of this is as follows: ```@example dae -using Lux, DiffEqFlux, Optimization, OptimizationNLopt, DifferentialEquations, Plots +using Lux, ComponentArrays, DiffEqFlux, Optimization, OptimizationNLopt, DifferentialEquations, Plots using Random rng = Random.default_rng() @@ -42,7 +42,7 @@ pinit, st = Lux.setup(rng, nn_dudt2) model_stiff_ndae = NeuralODEMM(nn_dudt2, (u, p, t) -> [u[1] + u[2] + u[3] - 1], tspan, M, Rodas5(autodiff = false), saveat = 0.1) -model_stiff_ndae(u₀, Lux.ComponentArray(pinit), st) +model_stiff_ndae(u₀, ComponentArray(pinit), st) function predict_stiff_ndae(p) return model_stiff_ndae(u₀, p, st)[1] @@ -59,11 +59,11 @@ end # return false # end -l1 = first(loss_stiff_ndae(Lux.ComponentArray(pinit))) +l1 = first(loss_stiff_ndae(ComponentArray(pinit))) adtype = Optimization.AutoZygote() optf = Optimization.OptimizationFunction((x, p) -> loss_stiff_ndae(x), adtype) -optprob = Optimization.OptimizationProblem(optf, Lux.ComponentArray(pinit)) +optprob = Optimization.OptimizationProblem(optf, ComponentArray(pinit)) result_stiff = Optimization.solve(optprob, NLopt.LD_LBFGS(), maxiters = 100) ``` @@ -72,7 +72,7 @@ result_stiff = Optimization.solve(optprob, NLopt.LD_LBFGS(), maxiters = 100) ### Load Packages ```@example dae2 -using Lux, DiffEqFlux, Optimization, OptimizationNLopt, DifferentialEquations, Plots +using Lux, ComponentArrays, DiffEqFlux, Optimization, OptimizationNLopt, DifferentialEquations, Plots using Random rng = Random.default_rng() @@ -142,7 +142,7 @@ pinit, st = Lux.setup(rng, nn_dudt2) model_stiff_ndae = NeuralODEMM(nn_dudt2, (u, p, t) -> [u[1] + u[2] + u[3] - 1], tspan, M, Rodas5(autodiff = false), saveat = 0.1) -model_stiff_ndae(u₀, Lux.ComponentArray(pinit), st) +model_stiff_ndae(u₀, ComponentArray(pinit), st) ``` Because this is a stiff problem, we have manually imposed that sum constraint via @@ -176,7 +176,7 @@ function loss_stiff_ndae(p) return loss, pred end -l1 = first(loss_stiff_ndae(Lux.ComponentArray(pinit))) +l1 = first(loss_stiff_ndae(ComponentArray(pinit))) ``` Notice that we are feeding the **parameters** of `model_stiff_ndae` to the `loss_stiff_ndae` @@ -206,6 +206,6 @@ Finally, training with `Optimization.solve` by passing: *loss function*, *model ```@example dae2 adtype = Optimization.AutoZygote() optf = Optimization.OptimizationFunction((x, p) -> loss_stiff_ndae(x), adtype) -optprob = Optimization.OptimizationProblem(optf, Lux.ComponentArray(pinit)) +optprob = Optimization.OptimizationProblem(optf, ComponentArray(pinit)) result_stiff = Optimization.solve(optprob, NLopt.LD_LBFGS(), maxiters = 100) ``` diff --git a/docs/src/examples/ode/exogenous_input.md b/docs/src/examples/ode/exogenous_input.md index 9d5588497..c74be6413 100644 --- a/docs/src/examples/ode/exogenous_input.md +++ b/docs/src/examples/ode/exogenous_input.md @@ -40,8 +40,8 @@ In the following example, a discrete exogenous input signal `ex` is defined and used as an input into the neural network of a neural ODE system. ```@example exogenous -using DifferentialEquations, Lux, DiffEqFlux, Optimization, OptimizationPolyalgorithms, - OptimizationFlux, Plots, Random +using DifferentialEquations, Lux, ComponentArrays, DiffEqFlux, Optimization, + OptimizationPolyalgorithms, OptimizationFlux, Plots, Random rng = Random.default_rng() tspan = (0.1f0, Float32(10.0)) @@ -88,7 +88,7 @@ end adtype = Optimization.AutoZygote() optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype) -optprob = Optimization.OptimizationProblem(optf, Lux.ComponentArray(p_model)) +optprob = Optimization.OptimizationProblem(optf, ComponentArray(p_model)) res0 = Optimization.solve(optprob, PolyOpt(), maxiters = 100) diff --git a/docs/src/tutorials/training_tips/local_minima.md b/docs/src/tutorials/training_tips/local_minima.md index 7e4d23d43..661f70712 100644 --- a/docs/src/tutorials/training_tips/local_minima.md +++ b/docs/src/tutorials/training_tips/local_minima.md @@ -16,7 +16,7 @@ before, except with one small twist: we wish to find the neural ODE that fits on `(0,5.0)`. Naively, we use the same training strategy as before: ```@example iterativefit -using DifferentialEquations, SciMLSensitivity, Optimization, OptimizationFlux +using DifferentialEquations, ComponentArrays, SciMLSensitivity, Optimization, OptimizationFlux using Lux, Plots, Random rng = Random.default_rng() @@ -38,7 +38,7 @@ dudt2 = Lux.Chain(ActivationFunction(x -> x .^ 3), Lux.Dense(16, 2)) pinit, st = Lux.setup(rng, dudt2) -pinit = Lux.ComponentArray(pinit) +pinit = ComponentArray(pinit) function neuralode_f(u, p, t) dudt2(u, p, st)[1] diff --git a/docs/src/tutorials/training_tips/multiple_nn.md b/docs/src/tutorials/training_tips/multiple_nn.md index cf3ecd08f..26d1c3418 100644 --- a/docs/src/tutorials/training_tips/multiple_nn.md +++ b/docs/src/tutorials/training_tips/multiple_nn.md @@ -7,7 +7,7 @@ this kind of study. The following is a fully working demo on the Fitzhugh-Nagumo ODE: ```@example -using Lux, DiffEqFlux, Optimization, OptimizationNLopt, DifferentialEquations, Random +using Lux, DiffEqFlux, ComponentArrays, Optimization, OptimizationNLopt, DifferentialEquations, Random rng = Random.default_rng() Random.seed!(rng, 1) @@ -38,13 +38,13 @@ NN_2 = Lux.Chain(Lux.Dense(3, 16, tanh), Lux.Dense(16, 1)) p2, st2 = Lux.setup(rng, NN_2) scaling_factor = 1.0f0 -p1 = Lux.ComponentArray(p1) -p2 = Lux.ComponentArray(p2) +p1 = ComponentArray(p1) +p2 = ComponentArray(p2) -p = Lux.ComponentArray{eltype(p1)}() -p = Lux.ComponentArray(p; p1) -p = Lux.ComponentArray(p; p2) -p = Lux.ComponentArray(p; scaling_factor) +p = ComponentArray{eltype(p1)}() +p = ComponentArray(p; p1) +p = ComponentArray(p; p2) +p = ComponentArray(p; scaling_factor) function dudt_(u, p, t) v, w = u From 56f07426a2c7579417374357b12282cc6104ac9b Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Sat, 25 Feb 2023 05:37:15 -0500 Subject: [PATCH 6/6] format --- docs/src/examples/dae/physical_constraints.md | 6 ++++-- docs/src/examples/ode/exogenous_input.md | 2 +- docs/src/tutorials/training_tips/local_minima.md | 3 ++- docs/src/tutorials/training_tips/multiple_nn.md | 3 ++- src/SciMLSensitivity.jl | 3 ++- src/quadrature_adjoint.jl | 6 +++--- test/partial_neural.jl | 2 +- test/steady_state.jl | 2 +- 8 files changed, 16 insertions(+), 11 deletions(-) diff --git a/docs/src/examples/dae/physical_constraints.md b/docs/src/examples/dae/physical_constraints.md index bf2cf5d77..feb18cafe 100644 --- a/docs/src/examples/dae/physical_constraints.md +++ b/docs/src/examples/dae/physical_constraints.md @@ -9,7 +9,8 @@ zeros, then we have a constraint defined by the right-hand side. Using terms must add to one. An example of this is as follows: ```@example dae -using Lux, ComponentArrays, DiffEqFlux, Optimization, OptimizationNLopt, DifferentialEquations, Plots +using Lux, ComponentArrays, DiffEqFlux, Optimization, OptimizationNLopt, + DifferentialEquations, Plots using Random rng = Random.default_rng() @@ -72,7 +73,8 @@ result_stiff = Optimization.solve(optprob, NLopt.LD_LBFGS(), maxiters = 100) ### Load Packages ```@example dae2 -using Lux, ComponentArrays, DiffEqFlux, Optimization, OptimizationNLopt, DifferentialEquations, Plots +using Lux, ComponentArrays, DiffEqFlux, Optimization, OptimizationNLopt, + DifferentialEquations, Plots using Random rng = Random.default_rng() diff --git a/docs/src/examples/ode/exogenous_input.md b/docs/src/examples/ode/exogenous_input.md index c74be6413..0634a0e85 100644 --- a/docs/src/examples/ode/exogenous_input.md +++ b/docs/src/examples/ode/exogenous_input.md @@ -40,7 +40,7 @@ In the following example, a discrete exogenous input signal `ex` is defined and used as an input into the neural network of a neural ODE system. ```@example exogenous -using DifferentialEquations, Lux, ComponentArrays, DiffEqFlux, Optimization, +using DifferentialEquations, Lux, ComponentArrays, DiffEqFlux, Optimization, OptimizationPolyalgorithms, OptimizationFlux, Plots, Random rng = Random.default_rng() diff --git a/docs/src/tutorials/training_tips/local_minima.md b/docs/src/tutorials/training_tips/local_minima.md index 661f70712..b8311ce28 100644 --- a/docs/src/tutorials/training_tips/local_minima.md +++ b/docs/src/tutorials/training_tips/local_minima.md @@ -16,7 +16,8 @@ before, except with one small twist: we wish to find the neural ODE that fits on `(0,5.0)`. Naively, we use the same training strategy as before: ```@example iterativefit -using DifferentialEquations, ComponentArrays, SciMLSensitivity, Optimization, OptimizationFlux +using DifferentialEquations, ComponentArrays, SciMLSensitivity, Optimization, + OptimizationFlux using Lux, Plots, Random rng = Random.default_rng() diff --git a/docs/src/tutorials/training_tips/multiple_nn.md b/docs/src/tutorials/training_tips/multiple_nn.md index 26d1c3418..0b2e26b77 100644 --- a/docs/src/tutorials/training_tips/multiple_nn.md +++ b/docs/src/tutorials/training_tips/multiple_nn.md @@ -7,7 +7,8 @@ this kind of study. The following is a fully working demo on the Fitzhugh-Nagumo ODE: ```@example -using Lux, DiffEqFlux, ComponentArrays, Optimization, OptimizationNLopt, DifferentialEquations, Random +using Lux, DiffEqFlux, ComponentArrays, Optimization, OptimizationNLopt, + DifferentialEquations, Random rng = Random.default_rng() Random.seed!(rng, 1) diff --git a/src/SciMLSensitivity.jl b/src/SciMLSensitivity.jl index 95543c135..33ff70c7a 100644 --- a/src/SciMLSensitivity.jl +++ b/src/SciMLSensitivity.jl @@ -34,7 +34,8 @@ import SciMLBase: unwrapped_f import SciMLBase: AbstractOverloadingSensitivityAlgorithm, AbstractSensitivityAlgorithm, AbstractForwardSensitivityAlgorithm, AbstractAdjointSensitivityAlgorithm, - AbstractSecondOrderSensitivityAlgorithm, AbstractShadowingSensitivityAlgorithm + AbstractSecondOrderSensitivityAlgorithm, + AbstractShadowingSensitivityAlgorithm include("hasbranching.jl") include("sensitivity_algorithms.jl") diff --git a/src/quadrature_adjoint.jl b/src/quadrature_adjoint.jl index c3252192e..45609579d 100644 --- a/src/quadrature_adjoint.jl +++ b/src/quadrature_adjoint.jl @@ -141,11 +141,11 @@ end original_mm = sol.prob.f.mass_matrix if original_mm === I || original_mm === (I, I) odefun = ODEFunction{ArrayInterface.ismutable(z0), true}(sense, - jac_prototype = adjoint_jac_prototype) + jac_prototype = adjoint_jac_prototype) else odefun = ODEFunction{ArrayInterface.ismutable(z0), true}(sense, - mass_matrix = sol.prob.f.mass_matrix', - jac_prototype = adjoint_jac_prototype) + mass_matrix = sol.prob.f.mass_matrix', + jac_prototype = adjoint_jac_prototype) end if RetCB return ODEProblem(odefun, z0, tspan, p, callback = cb), rcb diff --git a/test/partial_neural.jl b/test/partial_neural.jl index 9916eb5de..b2bb9a8bf 100644 --- a/test/partial_neural.jl +++ b/test/partial_neural.jl @@ -19,7 +19,7 @@ prob = ODEProblem(dudt2_, x, tspan, _p) solve(prob, Tsit5()) function predict_rd(θ) - Array(solve(prob, Tsit5(), u0 = θ[1:2], p = θ[3:end], abstol = 1f-7, reltol = 1f-5)) + Array(solve(prob, Tsit5(), u0 = θ[1:2], p = θ[3:end], abstol = 1.0f-7, reltol = 1.0f-5)) end loss_rd(p) = sum(abs2, x - 1 for x in predict_rd(p)) diff --git a/test/steady_state.jl b/test/steady_state.jl index 326905a98..3e0aeabf0 100644 --- a/test/steady_state.jl +++ b/test/steady_state.jl @@ -562,4 +562,4 @@ end @test du0≈Zdu0 atol=1e-4 @test dp≈Zdp atol=1e-4 end -end \ No newline at end of file +end