Skip to content

Commit

Permalink
Merge branch 'master' into sarray
Browse files Browse the repository at this point in the history
  • Loading branch information
ba2tro committed Jul 11, 2022
2 parents bdde708 + ede98e0 commit a5c4574
Show file tree
Hide file tree
Showing 28 changed files with 1,145 additions and 719 deletions.
9 changes: 5 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SciMLSensitivity"
uuid = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
authors = ["Christopher Rackauckas <[email protected]>", "Yingbo Ma <[email protected]>"]
version = "7.0.2"
version = "7.1.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down Expand Up @@ -44,7 +44,7 @@ ArrayInterfaceCore = "0.1.1"
ArrayInterfaceTracker = "0.1"
Cassette = "0.3.6"
ChainRulesCore = "0.10.7, 1"
DiffEqBase = "6.90"
DiffEqBase = "6.93"
DiffEqCallbacks = "2.17"
DiffEqNoiseProcess = "4.1.4, 5.0"
DiffEqOperators = "4.34"
Expand All @@ -63,7 +63,7 @@ RandomNumbers = "1.5.3"
RecursiveArrayTools = "2.4.2"
Reexport = "0.2, 1.0"
ReverseDiff = "1.9"
SciMLBase = "1.24"
SciMLBase = "1.42.3"
StochasticDiffEq = "6.20"
Tracker = "0.2"
Zygote = "0.6"
Expand All @@ -78,6 +78,7 @@ DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationFlux = "253f991c-a7b2-45f8-8852-8b9a9df78a86"
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
Expand All @@ -90,4 +91,4 @@ SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["AlgebraicMultigrid", "ComponentArrays", "Calculus", "Distributed", "DelayDiffEq", "Optimization", "OptimizationFlux", "OptimizationOptimJL", "Flux", "ReverseDiff", "SafeTestsets", "Test", "Random", "Pkg", "SteadyStateDiffEq", "NLsolve", "SparseArrays"]
test = ["AlgebraicMultigrid", "ComponentArrays", "Calculus", "Distributed", "DelayDiffEq", "Optimization", "OptimizationFlux", "OptimizationOptimJL", "Flux", "ReverseDiff", "SafeTestsets", "Test", "Random", "Pkg", "SteadyStateDiffEq", "NLsolve", "NonlinearSolve", "SparseArrays"]
2 changes: 1 addition & 1 deletion docs/src/ad_examples/adjoint_continuous_functional.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ To get the adjoint sensitivities, we call:
```@example continuousadjoint
prob = ODEProblem(f,[1.0;1.0],(0.0,10.0),p)
sol = solve(prob,DP8())
res = adjoint_sensitivities(sol,Vern9(),dg_continuous=dg,g=g,abstol=1e-8,reltol=1e-8)
res = adjoint_sensitivities(sol,Vern9(),dgdu_continuous=dg,g=g,abstol=1e-8,reltol=1e-8)
```

Notice that we can check this against autodifferentiation and numerical
Expand Down
2 changes: 1 addition & 1 deletion docs/src/ad_examples/direct_sensitivity.md
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ sensitivities, call:

```@example directsense
ts = 0:0.5:10
res = adjoint_sensitivities(sol,Vern9(),t=ts,dg_discrete=dg,abstol=1e-14,
res = adjoint_sensitivities(sol,Vern9(),t=ts,dgdu_discrete=dg,abstol=1e-14,
reltol=1e-14)
```

Expand Down
62 changes: 35 additions & 27 deletions src/adjoint_common.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
struct AdjointDiffCache{UF, PF, G, TJ, PJT, uType, JC, GC, PJC, JNC, PJNC, rateType, DG, DI,
struct AdjointDiffCache{UF, PF, G, TJ, PJT, uType, JC, GC, PJC, JNC, PJNC, rateType, DG1,
DG2, DI,
AI, FM}
uf::UF
pf::PF
Expand All @@ -12,7 +13,8 @@ struct AdjointDiffCache{UF, PF, G, TJ, PJT, uType, JC, GC, PJC, JNC, PJNC, rateT
jac_noise_config::JNC
paramjac_noise_config::PJNC
f_cache::rateType
dg::DG
dgdu::DG1
dgdp::DG2
diffvar_idxs::DI
algevar_idxs::AI
factorized_mass_matrix::FM
Expand All @@ -24,10 +26,11 @@ end
return (AdjointDiffCache, y)
"""
function adjointdiffcache(g::G, sensealg, discrete, sol, dg::DG, f; quad = false,
noiseterm = false, needs_jac = false) where {G, DG}
function adjointdiffcache(g::G, sensealg, discrete, sol, dgdu::DG1, dgdp::DG2, f;
quad = false,
noiseterm = false, needs_jac = false) where {G, DG1, DG2}
prob = sol.prob
if prob isa DiffEqBase.SteadyStateProblem
if prob isa Union{SteadyStateProblem, NonlinearProblem}
@unpack u0, p = prob
tspan = (nothing, nothing)
#elseif prob isa SDEProblem
Expand Down Expand Up @@ -71,10 +74,10 @@ function adjointdiffcache(g::G, sensealg, discrete, sol, dg::DG, f; quad = false
end

if !discrete
if dg !== nothing
if dgdu !== nothing
pg = nothing
pg_config = nothing
if dg isa Tuple && length(dg) == 2
if dgdp !== nothing
dg_val = (similar(u0, numindvar), similar(u0, numparams))
dg_val[1] .= false
dg_val[2] .= false
Expand All @@ -83,14 +86,15 @@ function adjointdiffcache(g::G, sensealg, discrete, sol, dg::DG, f; quad = false
dg_val .= false
end
else
if !(prob isa RODEProblem)
pg = UGradientWrapper(g, tspan[2], p)
else
pg = RODEUGradientWrapper(g, tspan[2], p, last(sol.W))
end
pg_config = build_grad_config(sensealg, pg, u0, p)
dg_val = similar(u0, numindvar) # number of funcs size
dg_val .= false
pgpu = UGradientWrapper(g, tspan[2], p)
pgpu_config = build_grad_config(sensealg, pgpu, u0, p)
pgpp = ParamGradientWrapper(g, tspan[2], u0)
pgpp_config = build_grad_config(sensealg, pgpp, p, p)
pg = (pgpu, pgpp)
pg_config = (pgpu_config, pgpp_config)
dg_val = (similar(u0, numindvar), similar(u0, numparams))
dg_val[1] .= false
dg_val[2] .= false
end
else
dg_val = nothing
Expand Down Expand Up @@ -119,7 +123,7 @@ function adjointdiffcache(g::G, sensealg, discrete, sol, dg::DG, f; quad = false
end
end

if prob isa DiffEqBase.SteadyStateProblem
if prob isa Union{SteadyStateProblem, NonlinearProblem}
y = copy(sol.u)
else
y = copy(sol.u[end])
Expand All @@ -134,7 +138,7 @@ function adjointdiffcache(g::G, sensealg, discrete, sol, dg::DG, f; quad = false
@assert sensealg.autojacvec !== nothing

if sensealg.autojacvec isa ReverseDiffVJP
if prob isa DiffEqBase.SteadyStateProblem
if prob isa Union{SteadyStateProblem, NonlinearProblem}
if DiffEqBase.isinplace(prob)
tape = ReverseDiff.GradientTape((y, _p)) do u, p
du1 = p !== nothing && p !== DiffEqBase.NullParameters() ?
Expand Down Expand Up @@ -354,7 +358,7 @@ function adjointdiffcache(g::G, sensealg, discrete, sol, dg::DG, f; quad = false
adjoint_cache = AdjointDiffCache(uf, pf, pg, J, pJ, dg_val,
jac_config, pg_config, paramjac_config,
jac_noise_config, paramjac_noise_config,
f_cache, dg, diffvar_idxs, algevar_idxs,
f_cache, dgdu, dgdp, diffvar_idxs, algevar_idxs,
factorized_mass_matrix, issemiexplicitdae)

return adjoint_cache, y
Expand All @@ -365,7 +369,8 @@ function getprob(S::SensitivityFunction)
end
inplace_sensitivity(S::SensitivityFunction) = isinplace(getprob(S))

struct ReverseLossCallback{λType, timeType, yType, RefType, FMType, AlgType, gType,
struct ReverseLossCallback{λType, timeType, yType, RefType, FMType, AlgType, dg1Type,
dg2Type,
cacheType}
isq::Bool
λ::λType
Expand All @@ -375,11 +380,12 @@ struct ReverseLossCallback{λType, timeType, yType, RefType, FMType, AlgType, gT
idx::Int
F::FMType
sensealg::AlgType
g::gType
dgdu::dg1Type
dgdp::dg2Type
diffcache::cacheType
end

function ReverseLossCallback(sensefun, λ, t, g, cur_time)
function ReverseLossCallback(sensefun, λ, t, dgdu, dgdp, cur_time)
@unpack sensealg, y = sensefun
isq = (sensealg isa QuadratureAdjoint)

Expand All @@ -388,11 +394,11 @@ function ReverseLossCallback(sensefun, λ, t, g, cur_time)
idx = length(prob.u0)

return ReverseLossCallback(isq, λ, t, y, cur_time, idx, factorized_mass_matrix,
sensealg, g, sensefun.diffcache)
sensealg, dgdu, dgdp, sensefun.diffcache)
end

function (f::ReverseLossCallback)(integrator)
@unpack isq, λ, t, y, cur_time, idx, F, sensealg, g = f
@unpack isq, λ, t, y, cur_time, idx, F, sensealg, dgdu, dgdp = f
@unpack diffvar_idxs, algevar_idxs, issemiexplicitdae, J, uf, f_cache, jac_config = f.diffcache

p, u = integrator.p, integrator.u
Expand Down Expand Up @@ -438,7 +444,8 @@ function (f::ReverseLossCallback)(integrator)
end

# handle discrete loss contributions
function generate_callbacks(sensefun, dg, λ, t, t0, callback, init_cb, terminated = false)
function generate_callbacks(sensefun, dgdu, dgdp, λ, t, t0, callback, init_cb,
terminated = false)
if sensefun isa NILSASSensitivityFunction
@unpack sensealg = sensefun.S
else
Expand All @@ -451,13 +458,14 @@ function generate_callbacks(sensefun, dg, λ, t, t0, callback, init_cb, terminat
cur_time = Ref(length(t))
end

reverse_cbs = setup_reverse_callbacks(callback, sensealg, dg, cur_time, terminated)
reverse_cbs = setup_reverse_callbacks(callback, sensealg, dgdu, dgdp, cur_time,
terminated)
init_cb || return reverse_cbs, nothing

# callbacks can lead to non-unique time points
_t, duplicate_iterator_times = separate_nonunique(t)

rlcb = ReverseLossCallback(sensefun, λ, t, dg, cur_time)
rlcb = ReverseLossCallback(sensefun, λ, t, dgdu, dgdp, cur_time)

if eltype(_t) !== typeof(t0)
_t = convert.(typeof(t0), _t)
Expand All @@ -467,7 +475,7 @@ function generate_callbacks(sensefun, dg, λ, t, t0, callback, init_cb, terminat
# handle duplicates (currently only for double occurrences)
if duplicate_iterator_times !== nothing
# use same ref for cur_time to cope with concrete_solve
cbrev_dupl_affect = ReverseLossCallback(sensefun, λ, t, dg, cur_time)
cbrev_dupl_affect = ReverseLossCallback(sensefun, λ, t, dgdu, dgdp, cur_time)
cb_dupl = PresetTimeCallback(duplicate_iterator_times[1], cbrev_dupl_affect)
return CallbackSet(cb, reverse_cbs, cb_dupl), duplicate_iterator_times
else
Expand Down
Loading

0 comments on commit a5c4574

Please sign in to comment.