diff --git a/docs/src/examples/dae/physical_constraints.md b/docs/src/examples/dae/physical_constraints.md
index ecf20fe65..feb18cafe 100644
--- a/docs/src/examples/dae/physical_constraints.md
+++ b/docs/src/examples/dae/physical_constraints.md
@@ -9,7 +9,8 @@ zeros, then we have a constraint defined by the right-hand side. Using
 terms must add to one. An example of this is as follows:
 ```@example dae
-using Lux, DiffEqFlux, Optimization, OptimizationNLopt, DifferentialEquations, Plots
+using Lux, ComponentArrays, DiffEqFlux, Optimization, OptimizationNLopt,
+    DifferentialEquations, Plots
 using Random
 rng = Random.default_rng()
@@ -42,7 +43,7 @@ pinit, st = Lux.setup(rng, nn_dudt2)
 model_stiff_ndae = NeuralODEMM(nn_dudt2, (u, p, t) -> [u[1] + u[2] + u[3] - 1], tspan, M, Rodas5(autodiff = false), saveat = 0.1)
-model_stiff_ndae(u₀, Lux.ComponentArray(pinit), st)
+model_stiff_ndae(u₀, ComponentArray(pinit), st)
 function predict_stiff_ndae(p)
     return model_stiff_ndae(u₀, p, st)[1]
@@ -59,11 +60,11 @@ end
 #     return false
 # end
-l1 = first(loss_stiff_ndae(Lux.ComponentArray(pinit)))
+l1 = first(loss_stiff_ndae(ComponentArray(pinit)))
 adtype = Optimization.AutoZygote()
 optf = Optimization.OptimizationFunction((x, p) -> loss_stiff_ndae(x), adtype)
-optprob = Optimization.OptimizationProblem(optf, Lux.ComponentArray(pinit))
+optprob = Optimization.OptimizationProblem(optf, ComponentArray(pinit))
 result_stiff = Optimization.solve(optprob, NLopt.LD_LBFGS(), maxiters = 100)
 ```
@@ -72,7 +73,8 @@ result_stiff = Optimization.solve(optprob, NLopt.LD_LBFGS(), maxiters = 100)
 ### Load Packages
 ```@example dae2
-using Lux, DiffEqFlux, Optimization, OptimizationNLopt, DifferentialEquations, Plots
+using Lux, ComponentArrays, DiffEqFlux, Optimization, OptimizationNLopt,
+    DifferentialEquations, Plots
 using Random
 rng = Random.default_rng()
@@ -142,7 +144,7 @@ pinit, st = Lux.setup(rng, nn_dudt2)
 model_stiff_ndae = NeuralODEMM(nn_dudt2, (u, p, t) -> [u[1] + u[2] + u[3] - 1], tspan, M, Rodas5(autodiff = false), saveat = 0.1)
-model_stiff_ndae(u₀, Lux.ComponentArray(pinit), st)
+model_stiff_ndae(u₀, ComponentArray(pinit), st)
 ```
 Because this is a stiff problem, we have manually imposed that sum constraint via
@@ -176,7 +178,7 @@ function loss_stiff_ndae(p)
     return loss, pred
 end
-l1 = first(loss_stiff_ndae(Lux.ComponentArray(pinit)))
+l1 = first(loss_stiff_ndae(ComponentArray(pinit)))
 ```
 Notice that we are feeding the **parameters** of `model_stiff_ndae` to the `loss_stiff_ndae`
@@ -206,6 +208,6 @@ Finally, training with `Optimization.solve` by passing: *loss function*, *model
 ```@example dae2
 adtype = Optimization.AutoZygote()
 optf = Optimization.OptimizationFunction((x, p) -> loss_stiff_ndae(x), adtype)
-optprob = Optimization.OptimizationProblem(optf, Lux.ComponentArray(pinit))
+optprob = Optimization.OptimizationProblem(optf, ComponentArray(pinit))
 result_stiff = Optimization.solve(optprob, NLopt.LD_LBFGS(), maxiters = 100)
 ```
diff --git a/docs/src/examples/ode/exogenous_input.md b/docs/src/examples/ode/exogenous_input.md
index 9d5588497..0634a0e85 100644
--- a/docs/src/examples/ode/exogenous_input.md
+++ b/docs/src/examples/ode/exogenous_input.md
@@ -40,8 +40,8 @@ In the following example, a discrete exogenous input signal `ex` is defined and
 used as an input into the neural network of a neural ODE system.
 ```@example exogenous
-using DifferentialEquations, Lux, DiffEqFlux, Optimization, OptimizationPolyalgorithms,
-    OptimizationFlux, Plots, Random
+using DifferentialEquations, Lux, ComponentArrays, DiffEqFlux, Optimization,
+    OptimizationPolyalgorithms, OptimizationFlux, Plots, Random
 rng = Random.default_rng()
 tspan = (0.1f0, Float32(10.0))
@@ -88,7 +88,7 @@ end
 adtype = Optimization.AutoZygote()
 optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
-optprob = Optimization.OptimizationProblem(optf, Lux.ComponentArray(p_model))
+optprob = Optimization.OptimizationProblem(optf, ComponentArray(p_model))
 res0 = Optimization.solve(optprob, PolyOpt(), maxiters = 100)
diff --git a/docs/src/tutorials/training_tips/local_minima.md b/docs/src/tutorials/training_tips/local_minima.md
index 7e4d23d43..b8311ce28 100644
--- a/docs/src/tutorials/training_tips/local_minima.md
+++ b/docs/src/tutorials/training_tips/local_minima.md
@@ -16,7 +16,8 @@ before, except with one small twist: we wish to find the neural ODE that fits
 on `(0,5.0)`. Naively, we use the same training strategy as before:
 ```@example iterativefit
-using DifferentialEquations, SciMLSensitivity, Optimization, OptimizationFlux
+using DifferentialEquations, ComponentArrays, SciMLSensitivity, Optimization,
+    OptimizationFlux
 using Lux, Plots, Random
 rng = Random.default_rng()
@@ -38,7 +39,7 @@ dudt2 = Lux.Chain(ActivationFunction(x -> x .^ 3), Lux.Dense(16, 2))
 pinit, st = Lux.setup(rng, dudt2)
-pinit = Lux.ComponentArray(pinit)
+pinit = ComponentArray(pinit)
 function neuralode_f(u, p, t)
     dudt2(u, p, st)[1]
diff --git a/docs/src/tutorials/training_tips/multiple_nn.md b/docs/src/tutorials/training_tips/multiple_nn.md
index cf3ecd08f..0b2e26b77 100644
--- a/docs/src/tutorials/training_tips/multiple_nn.md
+++ b/docs/src/tutorials/training_tips/multiple_nn.md
@@ -7,7 +7,8 @@ this kind of study.
 The following is a fully working demo on the Fitzhugh-Nagumo ODE:
 ```@example
-using Lux, DiffEqFlux, Optimization, OptimizationNLopt, DifferentialEquations, Random
+using Lux, DiffEqFlux, ComponentArrays, Optimization, OptimizationNLopt,
+    DifferentialEquations, Random
 rng = Random.default_rng()
 Random.seed!(rng, 1)
@@ -38,13 +39,13 @@ NN_2 = Lux.Chain(Lux.Dense(3, 16, tanh), Lux.Dense(16, 1))
 p2, st2 = Lux.setup(rng, NN_2)
 scaling_factor = 1.0f0
-p1 = Lux.ComponentArray(p1)
-p2 = Lux.ComponentArray(p2)
+p1 = ComponentArray(p1)
+p2 = ComponentArray(p2)
 
-p = Lux.ComponentArray{eltype(p1)}()
-p = Lux.ComponentArray(p; p1)
-p = Lux.ComponentArray(p; p2)
-p = Lux.ComponentArray(p; scaling_factor)
+p = ComponentArray{eltype(p1)}()
+p = ComponentArray(p; p1)
+p = ComponentArray(p; p2)
+p = ComponentArray(p; scaling_factor)
 function dudt_(u, p, t)
     v, w = u
diff --git a/src/SciMLSensitivity.jl b/src/SciMLSensitivity.jl
index 95543c135..33ff70c7a 100644
--- a/src/SciMLSensitivity.jl
+++ b/src/SciMLSensitivity.jl
@@ -34,7 +34,8 @@ import SciMLBase: unwrapped_f
 import SciMLBase: AbstractOverloadingSensitivityAlgorithm, AbstractSensitivityAlgorithm,
                   AbstractForwardSensitivityAlgorithm, AbstractAdjointSensitivityAlgorithm,
-                  AbstractSecondOrderSensitivityAlgorithm, AbstractShadowingSensitivityAlgorithm
+                  AbstractSecondOrderSensitivityAlgorithm,
+                  AbstractShadowingSensitivityAlgorithm
 include("hasbranching.jl")
 include("sensitivity_algorithms.jl")
diff --git a/src/backsolve_adjoint.jl b/src/backsolve_adjoint.jl
index e0c12564d..ccb0e2b12 100644
--- a/src/backsolve_adjoint.jl
+++ b/src/backsolve_adjoint.jl
@@ -147,7 +147,7 @@ end
     # check if solution was terminated, then use reduced time span
     terminated = false
     if hasfield(typeof(sol), :retcode)
-        if sol.retcode == :Terminated
+        if sol.retcode == ReturnCode.Terminated
             tspan = (tspan[1], sol.t[end])
             terminated = true
         end
@@ -253,7 +253,7 @@ end
     # check if solution was terminated, then use reduced time span
     terminated = false
     if hasfield(typeof(sol), :retcode)
-        if sol.retcode == :Terminated
+        if sol.retcode == ReturnCode.Terminated
             tspan = (tspan[1], sol.t[end])
             terminated = true
         end
@@ -370,7 +370,7 @@ end
     # check if solution was terminated, then use reduced time span
     terminated = false
     if hasfield(typeof(sol), :retcode)
-        if sol.retcode == :Terminated
+        if sol.retcode == ReturnCode.Terminated
             tspan = (tspan[1], sol.t[end])
             terminated = true
         end
diff --git a/src/interpolating_adjoint.jl b/src/interpolating_adjoint.jl
index 162f2e0bf..56348b92a 100644
--- a/src/interpolating_adjoint.jl
+++ b/src/interpolating_adjoint.jl
@@ -290,7 +290,7 @@ end
     # check if solution was terminated, then use reduced time span
     terminated = false
     if hasfield(typeof(sol), :retcode)
-        if sol.retcode == :Terminated
+        if sol.retcode == ReturnCode.Terminated
             tspan = (tspan[1], sol.t[end])
             terminated = true
         end
@@ -402,7 +402,7 @@ end
     # check if solution was terminated, then use reduced time span
     terminated = false
     if hasfield(typeof(sol), :retcode)
-        if sol.retcode == :Terminated
+        if sol.retcode == ReturnCode.Terminated
             tspan = (tspan[1], sol.t[end])
             terminated = true
         end
@@ -542,7 +542,7 @@ end
     # check if solution was terminated, then use reduced time span
     terminated = false
     if hasfield(typeof(sol), :retcode)
-        if sol.retcode == :Terminated
+        if sol.retcode == ReturnCode.Terminated
             tspan = (tspan[1], sol.t[end])
             terminated = true
         end
diff --git a/src/quadrature_adjoint.jl b/src/quadrature_adjoint.jl
index cd33b8230..45609579d 100644
--- a/src/quadrature_adjoint.jl
+++ b/src/quadrature_adjoint.jl
@@ -107,7 +107,7 @@ end
     terminated = false
     if hasfield(typeof(sol), :retcode)
-        if sol.retcode == :Terminated
+        if sol.retcode == ReturnCode.Terminated
             tspan = (tspan[1], sol.t[end])
             terminated = true
         end
@@ -141,11 +141,11 @@ end
     original_mm = sol.prob.f.mass_matrix
     if original_mm === I || original_mm === (I, I)
         odefun = ODEFunction{ArrayInterface.ismutable(z0), true}(sense,
-                                                                 jac_prototype = adjoint_jac_prototype)
+            jac_prototype = adjoint_jac_prototype)
     else
         odefun = ODEFunction{ArrayInterface.ismutable(z0), true}(sense,
-                                                                 mass_matrix = sol.prob.f.mass_matrix',
-                                                                 jac_prototype = adjoint_jac_prototype)
+            mass_matrix = sol.prob.f.mass_matrix',
+            jac_prototype = adjoint_jac_prototype)
     end
     if RetCB
         return ODEProblem(odefun, z0, tspan, p, callback = cb), rcb
@@ -359,12 +359,12 @@ function _adjoint_sensitivities(sol, sensealg::QuadratureAdjoint, alg; t = nothi
     end
     # correction for end interval.
-    if t[end] != sol.prob.tspan[2] && sol.retcode !== :Terminated
+    if t[end] != sol.prob.tspan[2] && sol.retcode !== ReturnCode.Terminated
         res .+= quadgk(integrand, t[end], sol.prob.tspan[end], atol = abstol, rtol = reltol)[1]
     end
-    if sol.retcode === :Terminated
+    if sol.retcode === ReturnCode.Terminated
         integrand = update_integrand_and_dgrad(res, sensealg, callback, integrand, adj_prob, sol, dgdu_discrete, dgdp_discrete, dλ, dgrad, t[end],
diff --git a/test/complex_no_u.jl b/test/complex_no_u.jl
index 7c589b9b7..297586084 100644
--- a/test/complex_no_u.jl
+++ b/test/complex_no_u.jl
@@ -1,5 +1,5 @@
 using OrdinaryDiffEq, SciMLSensitivity, LinearAlgebra, Optimization, OptimizationFlux, Flux
-nn = Chain(Dense(1, 16), Dense(16, 16, tanh), Dense(16, 2))
+nn = Chain(Dense(1, 16), Dense(16, 16, tanh), Dense(16, 2)) |> f64
 initial, re = Flux.destructure(nn)
 function ode2!(u, p, t)
diff --git a/test/partial_neural.jl b/test/partial_neural.jl
index 9de382bc9..b2bb9a8bf 100644
--- a/test/partial_neural.jl
+++ b/test/partial_neural.jl
@@ -19,9 +19,9 @@ prob = ODEProblem(dudt2_, x, tspan, _p)
 solve(prob, Tsit5())
 function predict_rd(θ)
-    Array(solve(prob, Tsit5(), u0 = θ[1:2], p = θ[3:end], abstol = 1e-7, reltol = 1e-5,
-        sensealg = TrackerAdjoint()))
+    Array(solve(prob, Tsit5(), u0 = θ[1:2], p = θ[3:end], abstol = 1.0f-7, reltol = 1.0f-5))
 end
+
 loss_rd(p) = sum(abs2, x - 1 for x in predict_rd(p))
 l = loss_rd(θ)
diff --git a/test/steady_state.jl b/test/steady_state.jl
index 2a3012df4..3e0aeabf0 100644
--- a/test/steady_state.jl
+++ b/test/steady_state.jl
@@ -73,14 +73,7 @@ Random.seed!(12345)
 @info "Calculate adjoint sensitivities from autodiff & numerical diff"
 function G(p)
     tmp_prob = remake(prob, u0 = convert.(eltype(p), prob.u0), p = p)
-    sol = solve(tmp_prob,
-        SSRootfind(nlsolve = (f!, u0, abstol) -> (res = NLsolve.nlsolve(f!,
-            u0,
-            autodiff = :forward,
-            method = :newton,
-            iterations = Int(1e6),
-            ftol = 1e-14);
-        res.zero)))
+    sol = solve(tmp_prob, DynamicSS(Rodas5()))
     A = convert(Array, sol)
     g(A, p, nothing)
 end
@@ -259,14 +252,7 @@ Random.seed!(12345)
 @testset "for u0: (should be zero, steady state does not depend on initial condition)" begin
     res5 = ForwardDiff.gradient(prob.u0) do u0
         tmp_prob = remake(prob, u0 = u0)
-        sol = solve(tmp_prob,
-            SSRootfind(nlsolve = (f!, u0, abstol) -> (res = NLsolve.nlsolve(f!,
-                u0,
-                autodiff = :forward,
-                method = :newton,
-                iterations = Int(1e6),
-                ftol = 1e-14);
-            res.zero)))
+        sol = solve(tmp_prob, DynamicSS(Rodas5()))
         A = convert(Array, sol)
         g(A, p, nothing)
     end
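As a quick, self-contained illustration of the two API migrations this patch applies across the docs and source (wrapping Lux parameters with `ComponentArrays.ComponentArray` instead of the old `Lux.ComponentArray` alias, and comparing solution return codes against `SciMLBase.ReturnCode` values instead of `Symbol`s), here is a minimal sketch. It is not part of the patch: the toy network, ODE, and names such as `nn`, `p`, and `prob` are invented for illustration, and it assumes recent releases of Lux, ComponentArrays, SciMLBase, and OrdinaryDiffEq.

```julia
# Illustrative sketch only; not taken from this patch.
using OrdinaryDiffEq, Lux, ComponentArrays, Random
using SciMLBase: ReturnCode

rng = Random.default_rng()
nn = Lux.Chain(Lux.Dense(2, 16, tanh), Lux.Dense(16, 2))
ps, st = Lux.setup(rng, nn)

# New pattern: wrap the Lux parameter NamedTuple with ComponentArrays.ComponentArray,
# matching the ComponentArray(pinit) calls in the updated docs.
p = ComponentArray(ps)

# A toy out-of-place neural ODE right-hand side built from the network above.
f(u, p, t) = first(nn(u, p, st))
prob = ODEProblem(f, Float32[1.0, 0.0], (0.0f0, 1.0f0), p)
sol = solve(prob, Tsit5())

# New pattern: compare retcodes against ReturnCode values rather than Symbols.
if sol.retcode == ReturnCode.Terminated
    println("solve terminated early; adjoints would use the reduced time span")
end
```

The `ReturnCode.Terminated` comparison mirrors the `sol.retcode == ReturnCode.Terminated` checks introduced above in `backsolve_adjoint.jl`, `interpolating_adjoint.jl`, and `quadrature_adjoint.jl`.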