Skip to content

Commit

Permalink
chore: format
Browse files Browse the repository at this point in the history
  • Loading branch information
DhairyaLGandhi committed Jun 24, 2024
1 parent 18bcfa9 commit b2abf78
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 60 deletions.
12 changes: 8 additions & 4 deletions src/adjoint_common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,8 @@ function adjointdiffcache(g::G, sensealg, discrete, sol, dgdu::DG1, dgdp::DG2, f
else
pf = RODEParamJacobianWrapper(unwrappedf, _t, y, _W)
end
paramjac_config = build_param_jac_config(sensealg, pf, y, SciMLStructures.replace(Tunable(), p, tunables))
paramjac_config = build_param_jac_config(
sensealg, pf, y, SciMLStructures.replace(Tunable(), p, tunables))
else
if !isRODE
pf = ParamGradientWrapper(unwrappedf, _t, y)
Expand Down Expand Up @@ -315,7 +316,8 @@ function adjointdiffcache(g::G, sensealg, discrete, sol, dgdu::DG1, dgdp::DG2, f
prob.noise_rate_prototype)
jac_noise_config = build_jac_config(sensealg, uf, u0)
end
paramjac_noise_config = build_param_jac_config(sensealg, pf, y, SciMLStructures.replace(Tunable(), p, tunables))
paramjac_noise_config = build_param_jac_config(
sensealg, pf, y, SciMLStructures.replace(Tunable(), p, tunables))
else
if StochasticDiffEq.is_diagonal_noise(prob)
pf = ParamGradientWrapper(unwrappedf, _t, y)
Expand Down Expand Up @@ -372,7 +374,8 @@ function get_paramjac_config(autojacvec::ReverseDiffVJP, p, f, y, _p, _t;
end
if isinplace
if !isRODE
__p = p isa SciMLBase.NullParameters ? _p : SciMLStructures.replace(Tunable(), p, _p)
__p = p isa SciMLBase.NullParameters ? _p :
SciMLStructures.replace(Tunable(), p, _p)
tape = ReverseDiff.GradientTape((y, __p, [_t])) do u, p, t
du1 = (p !== nothing && p !== DiffEqBase.NullParameters()) ?
similar(p, size(u)) : similar(u)
Expand All @@ -394,7 +397,8 @@ function get_paramjac_config(autojacvec::ReverseDiffVJP, p, f, y, _p, _t;
# GradientTape doesn't handle NullParameters; hence _p isa zeros(...)
# Cannot define replace(Tunable(), ::NullParameters, ::Vector)
# because hasportion(Tunable(), NullParameters) == false
__p = p isa SciMLBase.NullParameters ? _p : SciMLStructures.replace(Tunable(), p, _p)
__p = p isa SciMLBase.NullParameters ? _p :
SciMLStructures.replace(Tunable(), p, _p)
tape = ReverseDiff.GradientTape((y, __p, [_t])) do u, p, t
vec(f(u, p, first(t)))
end
Expand Down
4 changes: 2 additions & 2 deletions src/derivative_wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -441,9 +441,9 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::ReverseDiffVJP, dg
end

if p === nothing || p isa SciMLBase.NullParameters
tunables, repack = p, identity
tunables, repack = p, identity
else
tunables, repack, aliases = canonicalize(Tunable(), p)
tunables, repack, aliases = canonicalize(Tunable(), p)
end

u0 = state_values(prob)
Expand Down
7 changes: 4 additions & 3 deletions src/forward_sensitivity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,8 @@ function ODEForwardSensitivityProblem(f::F, args...; kwargs...) where {F}
end

function ODEForwardSensitivityProblem(prob::ODEProblem, alg; kwargs...)
ODEForwardSensitivityProblem(prob.f, state_values(prob), prob.tspan, parameter_values(prob), alg; kwargs...)
ODEForwardSensitivityProblem(
prob.f, state_values(prob), prob.tspan, parameter_values(prob), alg; kwargs...)
end

const FORWARD_SENSITIVITY_PARAMETER_COMPATIBILITY_MESSAGE = """
Expand Down Expand Up @@ -526,7 +527,6 @@ function extract_local_sensitivities(sol, asmatrix::Bool)
extract_local_sensitivities(sol, Val{asmatrix}())
end
function extract_local_sensitivities(sol, i::Integer, asmatrix::Val = Val(false))

_extract(sol, sol.prob.problem_type.sensealg, state_values(sol, i), asmatrix)
end
function extract_local_sensitivities(sol, i::Integer, asmatrix::Bool)
Expand Down Expand Up @@ -644,7 +644,8 @@ function SciMLBase.remake(
{uType, tType, isinplace, P, F, K}
_p = p === nothing ? parameter_values(prob) : p
_f = f === nothing ? prob.f.f : f
_u0 = u0 === nothing ? state_values(prob, 1:(prob.f.numindvar)) : u0[1:(prob.f.numindvar)]
_u0 = u0 === nothing ? state_values(prob, 1:(prob.f.numindvar)) :
u0[1:(prob.f.numindvar)]
_tspan = tspan === nothing ? prob.tspan : tspan
ODEForwardSensitivityProblem(_f, _u0,
_tspan, _p, prob.problem_type.sensealg;
Expand Down
24 changes: 12 additions & 12 deletions src/gauss_adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -375,10 +375,9 @@ function GaussIntegrand(sol, sensealg, checkpoints, dgdp = nothing)
if p === nothing || p isa DiffEqBase.NullParameters
tunables, repack = p, identity
elseif isscimlstructure(p)
tunables, repack, _= canonicalize(Tunable(), p)
tunables, repack, _ = canonicalize(Tunable(), p)
else
tunables, repack = Functors.functor(p)

tunables, repack = Functors.functor(p)
end

numparams = length(tunables)
Expand All @@ -394,15 +393,17 @@ function GaussIntegrand(sol, sensealg, checkpoints, dgdp = nothing)

if sensealg.autojacvec isa ReverseDiffVJP
tape = if DiffEqBase.isinplace(prob)
ReverseDiff.GradientTape((y, tunables, [tspan[2]])) do u, tunables, t
ReverseDiff.GradientTape((y, tunables, [tspan[2]])) do u, tunables, t
du1 = similar(tunables, size(u))
du1 .= false
unwrappedf(du1, u, SciMLStructures.replace(Tunable(), p, tunables), first(t))
unwrappedf(
du1, u, SciMLStructures.replace(Tunable(), p, tunables), first(t))
return vec(du1)
end
else
ReverseDiff.GradientTape((y, tunables, [tspan[2]])) do u, tunables, t
vec(unwrappedf(u, SciMLStructures.replace(Tunable(), p, tunables), first(t)))
vec(unwrappedf(
u, SciMLStructures.replace(Tunable(), p, tunables), first(t)))
end
end
if compile_tape(sensealg.autojacvec)
Expand Down Expand Up @@ -452,11 +453,11 @@ function vec_pjac!(out, λ, y, t, S::GaussIntegrand)
isautojacvec = get_jacvec(sensealg)
# y is aliased
if p === nothing || p isa SciMLBase.NullParameters
tunables, repack = p, identity
tunables, repack = p, identity
elseif isscimlstructure(p)
tunables, repack, _ = canonicalize(Tunable(), p)
tunables, repack, _ = canonicalize(Tunable(), p)
else
tunables, repack = Functors.functor(p)
tunables, repack = Functors.functor(p)
end

if !isautojacvec
Expand Down Expand Up @@ -540,18 +541,17 @@ function _adjoint_sensitivities(sol, sensealg::GaussAdjoint, alg; t = nothing,
corfunc_analytical = false,
callback = CallbackSet(),
kwargs...)

p = SymbolicIndexingInterface.parameter_values(sol)
if !isscimlstructure(p) && !isfunctor(p)
throw(SciMLStructuresCompatibilityError())
end

if p === nothing || p isa SciMLBase.NullParameters
tunables, repack = p, identity
tunables, repack = p, identity
elseif isscimlstructure(p)
tunables, repack, _ = canonicalize(Tunable(), p)
elseif isfunctor(p)
tunables, repack = Functors.functor(p)
tunables, repack = Functors.functor(p)
else
throw(SciMLStructuresCompatibilityError())
end
Expand Down
72 changes: 36 additions & 36 deletions src/interpolating_adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ function ODEInterpolatingAdjointSensitivityFunction(g, sensealg, discrete, sol,
dt = choose_dt((_sol.W.t[idx1] - _sol.W.t[idx1 + 1]), _sol.W.t, interval)

_ts = current_time(_sol)
cpsol = solve(remake(sol.prob, tspan = interval, u0 = sol(interval[1]),
cpsol = solve(
remake(sol.prob, tspan = interval, u0 = sol(interval[1]),
noise = forwardnoise),
sol.alg, save_noise = false; dt = dt, tstops = _ts[idx1:end],
tols...)
Expand Down Expand Up @@ -264,17 +265,17 @@ end

# g is either g(t,u,p) or discrete g(t,u,i)
@noinline function ODEAdjointProblem(sol, sensealg::InterpolatingAdjoint, alg,
t = nothing,
dgdu_discrete::DG1 = nothing,
dgdp_discrete::DG2 = nothing,
dgdu_continuous::DG3 = nothing,
dgdp_continuous::DG4 = nothing,
g::G = nothing,
::Val{RetCB} = Val(false);
checkpoints = current_time(sol),
callback = CallbackSet(),
reltol = nothing, abstol = nothing,
kwargs...) where {DG1, DG2, DG3, DG4, G, RetCB}
t = nothing,
dgdu_discrete::DG1 = nothing,
dgdp_discrete::DG2 = nothing,
dgdu_continuous::DG3 = nothing,
dgdp_continuous::DG4 = nothing,
g::G = nothing,
::Val{RetCB} = Val(false);
checkpoints = current_time(sol),
callback = CallbackSet(),
reltol = nothing, abstol = nothing,
kwargs...) where {DG1, DG2, DG3, DG4, G, RetCB}
dgdu_discrete === nothing && dgdu_continuous === nothing && g === nothing &&
error("Either `dgdu_discrete`, `dgdu_continuous`, or `g` must be specified.")
t !== nothing && dgdu_discrete === nothing && dgdp_discrete === nothing &&
Expand All @@ -287,12 +288,11 @@ end
u0 = state_values(sol.prob)

if p === nothing || p isa SciMLBase.NullParameters
tunables, repack = p, identity
tunables, repack = p, identity
else
tunables, repack, _ = canonicalize(Tunable(), p)
tunables, repack, _ = canonicalize(Tunable(), p)
end


## Force recompile mode until vjps are specialized to handle this!!!
f = if sol.prob.f isa ODEFunction &&
sol.prob.f.f isa FunctionWrappersWrappers.FunctionWrappersWrapper
Expand Down Expand Up @@ -394,17 +394,17 @@ end
end

@noinline function SDEAdjointProblem(sol, sensealg::InterpolatingAdjoint, alg,
t = nothing,
dgdu_discrete::DG1 = nothing,
dgdp_discrete::DG2 = nothing,
dgdu_continuous::DG3 = nothing,
dgdp_continuous::DG4 = nothing,
g::G = nothing;
checkpoints = current_time(sol),
callback = CallbackSet(),
reltol = nothing, abstol = nothing,
diffusion_jac = nothing, diffusion_paramjac = nothing,
kwargs...) where {DG1, DG2, DG3, DG4, G}
t = nothing,
dgdu_discrete::DG1 = nothing,
dgdp_discrete::DG2 = nothing,
dgdu_continuous::DG3 = nothing,
dgdp_continuous::DG4 = nothing,
g::G = nothing;
checkpoints = current_time(sol),
callback = CallbackSet(),
reltol = nothing, abstol = nothing,
diffusion_jac = nothing, diffusion_paramjac = nothing,
kwargs...) where {DG1, DG2, DG3, DG4, G}
dgdu_discrete === nothing && dgdu_continuous === nothing && g === nothing &&
error("Either `dgdu_discrete`, `dgdu_continuous`, or `g` must be specified.")
t !== nothing && dgdu_discrete === nothing && dgdp_discrete === nothing &&
Expand Down Expand Up @@ -536,16 +536,16 @@ end
end

@noinline function RODEAdjointProblem(sol, sensealg::InterpolatingAdjoint, alg,
t = nothing,
dgdu_discrete::DG1 = nothing,
dgdp_discrete::DG2 = nothing,
dgdu_continuous::DG3 = nothing,
dgdp_continuous::DG4 = nothing,
g::G = nothing;
checkpoints = current_time(sol),
callback = CallbackSet(),
reltol = nothing, abstol = nothing,
kwargs...) where {DG1, DG2, DG3, DG4, G}
t = nothing,
dgdu_discrete::DG1 = nothing,
dgdp_discrete::DG2 = nothing,
dgdu_continuous::DG3 = nothing,
dgdp_continuous::DG4 = nothing,
g::G = nothing;
checkpoints = current_time(sol),
callback = CallbackSet(),
reltol = nothing, abstol = nothing,
kwargs...) where {DG1, DG2, DG3, DG4, G}
dgdu_discrete === nothing && dgdu_continuous === nothing && g === nothing &&
error("Either `dgdu_discrete`, `dgdu_continuous`, or `g` must be specified.")
t !== nothing && dgdu_discrete === nothing && dgdp_discrete === nothing &&
Expand Down
5 changes: 2 additions & 3 deletions src/sensitivity_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -356,9 +356,8 @@ res3 = Calculus.gradient(G,[1.5,1.0,3.0])
```
"""
function adjoint_sensitivities(sol, args...;
sensealg = InterpolatingAdjoint(),
verbose = true, kwargs...)

sensealg = InterpolatingAdjoint(),
verbose = true, kwargs...)
p = SymbolicIndexingInterface.parameter_values(sol)
if !(p === nothing || p isa SciMLBase.NullParameters)
if !isscimlstructure(p) && !isfunctor(p)
Expand Down

0 comments on commit b2abf78

Please sign in to comment.