Skip to content

Commit

Permalink
clean up implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
Abhishek-1Bhatt committed Jul 6, 2022
1 parent 966f3b7 commit 2df07ee
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 224 deletions.
1 change: 0 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Expand Down
1 change: 0 additions & 1 deletion src/SciMLSensitivity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import ZygoteRules, Zygote, ReverseDiff
import ArrayInterfaceCore, ArrayInterfaceTracker
import Enzyme
import GPUArraysCore
import StaticArrays

using Cassette, DiffRules
using Core: CodeInfo, SlotNumber, SSAValue, ReturnNode, GotoIfNot
Expand Down
18 changes: 11 additions & 7 deletions src/adjoint_common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -401,13 +401,15 @@ function (f::ReverseLossCallback)(integrator)
copyto!(y, integrator.u[(end - idx + 1):end])
end

if u isa StaticArrays.SArray
gᵤ = isq ? λ : @view(λ[1:idx])
gᵤ = g(gᵤ, y, p, t[cur_time[]], cur_time[])
else
# if u isa StaticArrays.SArray
if ArrayInterfaceCore.ismutable(u)
# Warning: alias here! Be careful with λ
gᵤ = isq ? λ : @view(λ[1:idx])
g(gᵤ, y, p, t[cur_time[]], cur_time[])
else
@assert sensealg isa QuadratureAdjoint
gᵤ = isq ? λ : @view(λ[1:idx])
gᵤ = g(gᵤ, y, p, t[cur_time[]], cur_time[])
end

if issemiexplicitdae
Expand All @@ -425,10 +427,12 @@ function (f::ReverseLossCallback)(integrator)
F !== I && F !== (I, I) && ldiv!(F, Δλd)
end

if u isa StaticArrays.SArray
integrator.u += Δλd
else
# if u isa StaticArrays.SArray
if ArrayInterfaceCore.ismutable(u)
u[diffvar_idxs] .+= Δλd
else
@assert sensealg isa QuadratureAdjoint
integrator.u += Δλd
end
u_modified!(integrator, true)
cur_time[] -= 1
Expand Down
236 changes: 26 additions & 210 deletions src/concrete_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,11 @@ function DiffEqBase._concrete_solve_adjoint(prob, alg,
x = vec(Δ[1])
_out[_save_idxs] .= adapt(outtype, @view(x[_save_idxs]))
elseif _save_idxs isa Colon
vec(_out) .= adapt(outtype, vec(Δ[1]))
if ArrayInterfaceCore.ismutable(u)
vec(_out) .= adapt(outtype, vec(Δ[1]))
else
_out = adapt(outtype, vec(Δ[1]))
end
else
vec(@view(_out[_save_idxs])) .= adapt(outtype,
vec(Δ[1])[_save_idxs])
Expand All @@ -269,7 +273,11 @@ function DiffEqBase._concrete_solve_adjoint(prob, alg,
x = vec(Δ)
_out[_save_idxs] .= adapt(outtype, @view(x[_save_idxs]))
elseif _save_idxs isa Colon
vec(_out) .= adapt(outtype, vec(Δ))
if ArrayInterfaceCore.ismutable(u)
vec(_out) .= adapt(outtype, vec(Δ))
else
_out = adapt(outtype, vec(Δ))
end
else
x = vec(Δ)
vec(@view(_out[_save_idxs])) .= adapt(outtype, @view(x[_save_idxs]))
Expand All @@ -283,7 +291,11 @@ function DiffEqBase._concrete_solve_adjoint(prob, alg,
if typeof(_save_idxs) <: Number
_out[_save_idxs] = @view(x[_save_idxs])
elseif _save_idxs isa Colon
vec(_out) .= vec(x)
if ArrayInterfaceCore.ismutable(u)
vec(_out) .= vec(x)
else
_out = vec(x)
end
else
vec(@view(_out[_save_idxs])) .= vec(@view(x[_save_idxs]))
end
Expand All @@ -293,9 +305,15 @@ function DiffEqBase._concrete_solve_adjoint(prob, alg,
reshape(Δ, prod(size(Δ)[1:(end - 1)]),
size(Δ)[end])[_save_idxs, i])
elseif _save_idxs isa Colon
vec(_out) .= vec(adapt(outtype,
reshape(Δ, prod(size(Δ)[1:(end - 1)]),
size(Δ)[end])[:, i]))
if ArrayInterfaceCore.ismutable(u)
vec(_out) .= vec(adapt(outtype,
reshape(Δ, prod(size(Δ)[1:(end - 1)]),
size(Δ)[end])[:, i]))
else
_out = vec(adapt(outtype,
reshape(Δ, prod(size(Δ)[1:(end - 1)]),
size(Δ)[end])[:, i]))
end
else
vec(@view(_out[_save_idxs])) .= vec(adapt(outtype,
reshape(Δ,
Expand All @@ -305,211 +323,9 @@ function DiffEqBase._concrete_solve_adjoint(prob, alg,
end
end
end
end

if haskey(kwargs_adj, :callback_adj)
cb2 = CallbackSet(cb, kwargs[:callback_adj])
else
cb2 = cb
end

du0, dp = adjoint_sensitivities(sol, alg, args...; t = ts, dg_discrete = df,
sensealg = sensealg,
callback = cb2,
kwargs_adj...)

du0 = reshape(du0, size(u0))
dp = p === nothing || p === DiffEqBase.NullParameters() ? nothing :
reshape(dp', size(p))

if originator isa SciMLBase.TrackerOriginator ||
originator isa SciMLBase.ReverseDiffOriginator
(NoTangent(), NoTangent(), du0, dp, NoTangent(),
ntuple(_ -> NoTangent(), length(args))...)
else
(NoTangent(), NoTangent(), NoTangent(), du0, dp, NoTangent(),
ntuple(_ -> NoTangent(), length(args))...)
end
end
out, adjoint_sensitivity_backpass
end

function DiffEqBase._concrete_solve_adjoint(prob, alg,
sensealg::AbstractAdjointSensitivityAlgorithm,
u0::StaticArrays.SVector, p, originator::SciMLBase.ADOriginator,
args...; save_start = true, save_end = true,
saveat = eltype(prob.tspan)[],
save_idxs = nothing,
kwargs...)
if !(typeof(p) <: Union{Nothing, SciMLBase.NullParameters, AbstractArray}) ||
(p isa AbstractArray && !Base.isconcretetype(eltype(p)))
throw(AdjointSensitivityParameterCompatibilityError())
end

# Remove saveat, etc. from kwargs since it's handled separately
# and letting it jump back in there can break the adjoint
kwargs_prob = NamedTuple(filter(x -> x[1] != :saveat && x[1] != :save_start &&
x[1] != :save_end && x[1] != :save_idxs,
prob.kwargs))

if haskey(kwargs, :callback)
cb = track_callbacks(CallbackSet(kwargs[:callback]), prob.tspan[1], prob.u0, prob.p,
sensealg)
_prob = remake(prob; u0 = u0, p = p, kwargs = merge(kwargs_prob, (; callback = cb)))
else
cb = nothing
_prob = remake(prob; u0 = u0, p = p, kwargs = kwargs_prob)
end

# Remove callbacks, saveat, etc. from kwargs since it's handled separately
kwargs_fwd = NamedTuple{Base.diff_names(Base._nt_names(values(kwargs)), (:callback,))}(values(kwargs))

# Capture the callback_adj for the reverse pass and remove both callbacks
kwargs_adj = NamedTuple{
Base.diff_names(Base._nt_names(values(kwargs)),
(:callback_adj, :callback))}(values(kwargs))
isq = sensealg isa QuadratureAdjoint
if typeof(sensealg) <: BacksolveAdjoint
sol = solve(_prob, alg, args...; save_noise = true,
save_start = save_start, save_end = save_end,
saveat = saveat, kwargs_fwd...)
elseif ischeckpointing(sensealg)
sol = solve(_prob, alg, args...; save_noise = true,
save_start = true, save_end = true,
saveat = saveat, kwargs_fwd...)
else
sol = solve(_prob, alg, args...; save_noise = true, save_start = true,
save_end = true, kwargs_fwd...)
end

# Force `save_start` and `save_end` in the forward pass. This forces the
# solver to do the backsolve all the way back to `u0`. Since the start aliases
# `_prob.u0`, this doesn't actually use more memory, but it cleans up the
# implementation and makes the `save_start` and `save_end` args safe.
if typeof(sensealg) <: BacksolveAdjoint
# Saving behavior unchanged
ts = sol.t
only_end = length(ts) == 1 && ts[1] == _prob.tspan[2]
out = DiffEqBase.sensitivity_solution(sol, sol.u, ts)
elseif saveat isa Number
if _prob.tspan[2] > _prob.tspan[1]
ts = _prob.tspan[1]:convert(typeof(_prob.tspan[2]), abs(saveat)):_prob.tspan[2]
else
ts = _prob.tspan[2]:convert(typeof(_prob.tspan[2]), abs(saveat)):_prob.tspan[1]
end
# if _prob.tspan[2] - _prob.tspan[1] is not a multiple of saveat, one loses the last ts value
sol.t[end] !== ts[end] && (ts = fix_endpoints(sensealg, sol, ts))
if cb === nothing
_out = sol(ts)
else
_, duplicate_iterator_times = separate_nonunique(sol.t)
_out, ts = out_and_ts(ts, duplicate_iterator_times, sol)
end

out = if save_idxs === nothing
out = DiffEqBase.sensitivity_solution(sol, _out.u, ts)
else
out = DiffEqBase.sensitivity_solution(sol,
[_out[i][save_idxs]
for i in 1:length(_out)], ts)
end
only_end = length(ts) == 1 && ts[1] == _prob.tspan[2]
elseif isempty(saveat)
no_start = !save_start
no_end = !save_end
sol_idxs = 1:length(sol)
no_start && (sol_idxs = sol_idxs[2:end])
no_end && (sol_idxs = sol_idxs[1:(end - 1)])
only_end = length(sol_idxs) <= 1
_u = sol.u[sol_idxs]
u = save_idxs === nothing ? _u : [x[save_idxs] for x in _u]
ts = sol.t[sol_idxs]
out = DiffEqBase.sensitivity_solution(sol, u, ts)
else
_saveat = saveat isa Array ? sort(saveat) : saveat # for minibatching
if cb === nothing
_saveat = eltype(_saveat) <: typeof(prob.tspan[2]) ?
convert.(typeof(_prob.tspan[2]), _saveat) : _saveat
ts = _saveat
_out = sol(ts)
else
_ts, duplicate_iterator_times = separate_nonunique(sol.t)
_out, ts = out_and_ts(_saveat, duplicate_iterator_times, sol)
end

out = if save_idxs === nothing
out = DiffEqBase.sensitivity_solution(sol, _out.u, ts)
else
out = DiffEqBase.sensitivity_solution(sol,
[_out[i][save_idxs]
for i in 1:length(_out)], ts)
end
only_end = length(ts) == 1 && ts[1] == _prob.tspan[2]
end

_save_idxs = save_idxs === nothing ? Colon() : save_idxs

function adjoint_sensitivity_backpass(Δ)
function df(_out, u, p, t, i)
outtype = typeof(_out) <: SubArray ?
DiffEqBase.parameterless_type(_out.parent) :
DiffEqBase.parameterless_type(_out)
if only_end
eltype(Δ) <: NoTangent && return
if typeof(Δ) <: AbstractArray{<:AbstractArray} && length(Δ) == 1 && i == 1
# user did sol[end] on only_end
if typeof(_save_idxs) <: Number
x = vec(Δ[1])
_out[_save_idxs] .= adapt(outtype, @view(x[_save_idxs]))
elseif _save_idxs isa Colon
_out = adapt(outtype, vec(Δ[1]))
else
vec(@view(_out[_save_idxs])) .= adapt(outtype,
vec(Δ[1])[_save_idxs])
end
else
Δ isa NoTangent && return
if typeof(_save_idxs) <: Number
x = vec(Δ)
_out[_save_idxs] .= adapt(outtype, @view(x[_save_idxs]))
elseif _save_idxs isa Colon
_out = adapt(outtype, vec(Δ))
else
x = vec(Δ)
vec(@view(_out[_save_idxs])) .= adapt(outtype, @view(x[_save_idxs]))
end
end
else
!Base.isconcretetype(eltype(Δ)) &&
(Δ[i] isa NoTangent || eltype(Δ) <: NoTangent) && return
if typeof(Δ) <: AbstractArray{<:AbstractArray} || typeof(Δ) <: DESolution
x = Δ[i]
if typeof(_save_idxs) <: Number
_out[_save_idxs] = @view(x[_save_idxs])
elseif _save_idxs isa Colon
_out = vec(x)
else
vec(@view(_out[_save_idxs])) .= vec(@view(x[_save_idxs]))
end
else
if typeof(_save_idxs) <: Number
_out[_save_idxs] = adapt(outtype,
reshape(Δ, prod(size(Δ)[1:(end - 1)]),
size(Δ)[end])[_save_idxs, i])
elseif _save_idxs isa Colon
_out = vec(adapt(outtype,
reshape(Δ, prod(size(Δ)[1:(end - 1)]),
size(Δ)[end])[:, i]))######Required Assignment#################
else
vec(@view(_out[_save_idxs])) .= vec(adapt(outtype,
reshape(Δ,
prod(size(Δ)[1:(end - 1)]),
size(Δ)[end])[:,
i]))
end
end
if !(ArrayInterfaceCore.ismutable(u0))
return _out
end
return _out
end

if haskey(kwargs_adj, :callback_adj)
Expand Down
11 changes: 6 additions & 5 deletions src/quadrature_adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ end
odefun = ODEFunction(sense, mass_matrix = sol.prob.f.mass_matrix',
jac_prototype = adjoint_jac_prototype)
end
return ODEProblem{!(z0 isa StaticArrays.SArray)}(odefun, z0, tspan, p, callback = cb)
return ODEProblem{ArrayInterfaceCore.ismutable(z0)}(odefun, z0, tspan, p, callback = cb)
end

struct AdjointSensitivityIntegrand{pType, uType, lType, rateType, S, AS, PF, PJC, PJT, DGP,
Expand Down Expand Up @@ -210,12 +210,13 @@ end
function (S::AdjointSensitivityIntegrand)(out, t)
@unpack y, λ, pJ, pf, p, f_cache, dgdp_cache, paramjac_config, sensealg, sol, adj_sol = S
f = sol.prob.f
if eltype(sol.u) <: StaticArrays.SArray
y = sol(t)
λ = adj_sol(t)
else
# if eltype(sol.u) <: StaticArrays.SArray
if ArrayInterfaceCore.ismutable(eltype(sol.u))
sol(y, t)
adj_sol(λ, t)
else
y = sol(t)
λ = adj_sol(t)
end
isautojacvec = get_jacvec(sensealg)
# y is aliased
Expand Down

0 comments on commit 2df07ee

Please sign in to comment.