Skip to content

Commit

Permalink
Update default ODE solver tests to use new OrdinaryDiffEq infrastructure
Browse files Browse the repository at this point in the history
Final piece of #1035
  • Loading branch information
ChrisRackauckas committed Jun 6, 2024
1 parent b02bf7d commit f28a2d7
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 22 deletions.
1 change: 0 additions & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ jobs:
- Core5
- Core6
- Core7
- DiffEq
- SDE1
- SDE2
- SDE3
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ NLsolve = "4.5.1"
NonlinearSolve = "3.0.1"
Optimization = "3.19.3"
OptimizationOptimisers = "0.1.6"
OrdinaryDiffEq = "6.68.1"
OrdinaryDiffEq = "6.81.1"
Parameters = "0.12"
Pkg = "1.10"
PreallocationTools = "0.4.4"
Expand Down
13 changes: 11 additions & 2 deletions test/diffeq/default_alg_diff.jl → test/default_alg_diff.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using ComponentArrays, DifferentialEquations, Lux, Random, SciMLSensitivity, Zygote
using ComponentArrays, OrdinaryDiffEq, Lux, Random, SciMLSensitivity, Zygote

function f(du, u, p, t)
du .= first(nn(u, p, st))
Expand All @@ -13,7 +13,16 @@ r = rand(Float32, 8, 64)

function f2(x)
prob = ODEProblem(f, r, (0.0f0, 1.0f0), x)
sol = solve(prob; sensealg = InterpolatingAdjoint(; autodiff = true, autojacvec = true))
sol = solve(prob, OrdinaryDiffEq.DefaultODEAlgorithm())
sum(last(sol.u))
end

f2(ps)
Zygote.gradient(f2, ps)

function f2(x)
prob = ODEProblem(f, r, (0.0f0, 1.0f0), x)
sol = solve(prob)
sum(last(sol.u))
end

Expand Down
5 changes: 0 additions & 5 deletions test/diffeq/Project.toml

This file was deleted.

14 changes: 1 addition & 13 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,6 @@ function activate_gpu_env()
Pkg.instantiate()
end

function activate_diffeq_env()
Pkg.activate("diffeq")
Pkg.develop(PackageSpec(path = dirname(@__DIR__)))
Pkg.instantiate()
end

@time @testset "SciMLSensitivity" begin
if GROUP == "All" || GROUP == "Core1" || GROUP == "Downstream"
@testset "Core1" begin
Expand Down Expand Up @@ -49,6 +43,7 @@ end

if GROUP == "All" || GROUP == "Core3" || GROUP == "Downstream"
@testset "Core 3" begin
@time @safetestset "Default DiffEq Alg" include("default_alg_diff.jl")
@time @safetestset "Adjoint Sensitivity" include("adjoint.jl")
@time @safetestset "automatic sensealg choice" include("automatic_sensealg_choice.jl")
end
Expand Down Expand Up @@ -150,13 +145,6 @@ end
end
end

if GROUP == "DiffEq"
@testset "DiffEq" begin
activate_diffeq_env()
@time @safetestset "Default DiffEq Alg" include("diffeq/default_alg_diff.jl")
end
end

if GROUP == "GPU"
@testset "GPU" begin
activate_gpu_env()
Expand Down

0 comments on commit f28a2d7

Please sign in to comment.