Skip to content

Commit

Permalink
chore: SciMLBase.NullParameters -> DiffEqBase.NullParameters
Browse files Browse the repository at this point in the history
  • Loading branch information
DhairyaLGandhi committed Jun 27, 2024
1 parent 3559012 commit a752804
Show file tree
Hide file tree
Showing 7 changed files with 24 additions and 24 deletions.
8 changes: 4 additions & 4 deletions src/adjoint_common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ function adjointdiffcache(g::G, sensealg, discrete, sol, dgdu::DG1, dgdp::DG2, f
prob = sol.prob
u0 = state_values(prob)
p = parameter_values(prob)
if p === nothing || p isa SciMLBase.NullParameters
if p === nothing || p === DiffEqBase.NullParameters()
tunables, repack = p, identity
elseif isscimlstructure(p)
tunables, repack, _ = canonicalize(Tunable(), p)
Expand Down Expand Up @@ -367,14 +367,14 @@ function get_paramjac_config(autojacvec::ReverseDiffVJP, p, f, y, _p, _t;
numindvar = nothing, alg = nothing, isinplace = true,
isRODE = false, _W = nothing)
# f = unwrappedf
if p === nothing || p isa SciMLBase.NullParameters
if p === nothing || p === DiffEqBase.NullParameters()
tunables, repack = p, identity
else
tunables, repack, aliases = canonicalize(Tunable(), p)
end
if isinplace
if !isRODE
__p = p isa SciMLBase.NullParameters ? _p :
__p = p === DiffEqBase.NullParameters() ? _p :
SciMLStructures.replace(Tunable(), p, _p)
tape = ReverseDiff.GradientTape((y, __p, [_t])) do u, p, t
du1 = (p !== nothing && p !== DiffEqBase.NullParameters()) ?
Expand All @@ -397,7 +397,7 @@ function get_paramjac_config(autojacvec::ReverseDiffVJP, p, f, y, _p, _t;
# GradientTape doesn't handle NullParameters; hence _p isa zeros(...)
# Cannot define replace(Tunable(), ::NullParameters, ::Vector)
# because hasportion(Tunable(), NullParameters) == false
__p = p isa SciMLBase.NullParameters ? _p :
__p = p === DiffEqBase.NullParameters() ? _p :
SciMLStructures.replace(Tunable(), p, _p)
tape = ReverseDiff.GradientTape((y, __p, [_t])) do u, p, t
vec(f(u, p, first(t)))
Expand Down
2 changes: 1 addition & 1 deletion src/backsolve_adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ end

p = parameter_values(sol.prob)
u0 = state_values(sol.prob)
if p === nothing || p isa SciMLBase.NullParameters
if p === nothing || p === DiffEqBase.NullParameters()
tunables, repack = p, identity
else
tunables, repack, _ = canonicalize(Tunable(), p)
Expand Down
22 changes: 11 additions & 11 deletions src/concrete_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ function inplace_vjp(prob, u0, p, verbose, repack)

vjp = try
f = unwrapped_f(prob.f)
if p === nothing || p isa SciMLBase.NullParameters
if p === nothing || p === DiffEqBase.NullParameters()
ReverseDiff.GradientTape((copy(u0), [prob.tspan[1]])) do u, t
du1 = similar(u, size(u))
du1 .= 0
Expand Down Expand Up @@ -97,7 +97,7 @@ function automatic_sensealg_choice(
t = prob.tspan[1]
λ = zero(prob.u0)

if p === nothing || p isa SciMLBase.NullParameters
if p === nothing || p === DiffEqBase.NullParameters()
_dy, back = Zygote.pullback(y) do u
vec(f(u, p, t))
end
Expand All @@ -120,7 +120,7 @@ function automatic_sensealg_choice(

if vjp == false
vjp = try
if p === nothing || p isa SciMLBase.NullParameters
if p === nothing || p === DiffEqBase.NullParameters()
ReverseDiff.gradient((u) -> sum(prob.f(u, p, prob.tspan[1])), u0)
else
ReverseDiff.gradient(
Expand All @@ -145,7 +145,7 @@ function automatic_sensealg_choice(
t = prob.tspan[1]
λ = zero(prob.u0)

if p === nothing || p isa SciMLBase.NullParameters
if p === nothing || p === DiffEqBase.NullParameters()
_dy, back = Tracker.forward(y) do u
vec(f(u, p, t))
end
Expand Down Expand Up @@ -247,13 +247,13 @@ function DiffEqBase._concrete_solve_adjoint(
has_cb = false
end

if !(p === nothing || p isa SciMLBase.NullParameters)
if !(p === nothing || p === DiffEqBase.NullParameters())
if !isscimlstructure(p) && !isfunctor(p)
throw(SciMLStructuresCompatibilityError())
end
end

if p === nothing || p isa SciMLBase.NullParameters
if p === nothing || p === DiffEqBase.NullParameters()
tunables, repack = p, identity
elseif isscimlstructure(p)
tunables, repack, aliases = canonicalize(Tunable(), p)
Expand All @@ -280,13 +280,13 @@ function DiffEqBase._concrete_solve_adjoint(
sensealg::Nothing, u0, p,
originator::SciMLBase.ADOriginator, args...;
verbose = true, kwargs...)
if !(p === nothing || p isa SciMLBase.NullParameters)
if !(p === nothing || p === DiffEqBase.NullParameters())
if !isscimlstructure(p) && !isfunctor(p)
throw(SciMLStructuresCompatibilityError())
end
end

if p === nothing || p isa SciMLBase.NullParameters
if p === nothing || p === DiffEqBase.NullParameters()
tunables, repack = p, identity
elseif isscimlstructure(p)
tunables, repack, aliases = canonicalize(Tunable(), p)
Expand Down Expand Up @@ -376,7 +376,7 @@ function DiffEqBase._concrete_solve_adjoint(
throw(AdjointSensitivityParameterCompatibilityError())
end

if p === nothing || p isa SciMLBase.NullParameters
if p === nothing || p === DiffEqBase.NullParameters()
tunables, repack = p, identity
elseif isscimlstructure(p)
tunables, repack, aliases = canonicalize(Tunable(), p)
Expand Down Expand Up @@ -1175,13 +1175,13 @@ function DiffEqBase._concrete_solve_adjoint(
throw(EnzymeTrackedRealError())
end

if !(p === nothing || p isa SciMLBase.NullParameters)
if !(p === nothing || p === DiffEqBase.NullParameters())
if !isscimlstructure(p)
throw(SciMLStructuresCompatibilityError())
end
end

if p === nothing || p isa SciMLBase.NullParameters
if p === nothing || p === DiffEqBase.NullParameters()
tunables, repack = p, identity
else
tunables, repack, _ = canonicalize(Tunable(), p)
Expand Down
6 changes: 3 additions & 3 deletions src/derivative_wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::ReverseDiffVJP, dg
_p = p
end

if p === nothing || p isa SciMLBase.NullParameters
if p === nothing || p === DiffEqBase.NullParameters()
tunables, repack = p, identity
else
tunables, repack, aliases = canonicalize(Tunable(), p)
Expand Down Expand Up @@ -660,7 +660,7 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::EnzymeVJP, dgrad,

prob = getprob(S)
_p = parameter_values(prob)
if _p === nothing || _p isa SciMLBase.NullParameters
if _p === nothing || _p === DiffEqBase.NullParameters()
tunables, repack = _p, identity
else
tunables, repack, _ = canonicalize(Tunable(), _p)
Expand Down Expand Up @@ -898,7 +898,7 @@ function _jacNoise!(λ, y, p, t, S::TS, isnoise::ZygoteVJP, dgrad, dλ,
prob = getprob(S)
p_ = parameter_values(prob)

if p_ === nothing || p_ isa SciMLBase.NullParameters
if p_ === nothing || p_ === DiffEqBase.NullParameters()
tunables, repack = p_, identity
else
tunables, repack, _ = canonicalize(Tunable(), p_)
Expand Down
4 changes: 2 additions & 2 deletions src/gauss_adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ function vec_pjac!(out, λ, y, t, S::GaussIntegrand)
f = sol.prob.f
isautojacvec = get_jacvec(sensealg)
# y is aliased
if p === nothing || p isa SciMLBase.NullParameters
if p === nothing || p === DiffEqBase.NullParameters()
tunables, repack = p, identity
elseif isscimlstructure(p)
tunables, repack, _ = canonicalize(Tunable(), p)
Expand Down Expand Up @@ -546,7 +546,7 @@ function _adjoint_sensitivities(sol, sensealg::GaussAdjoint, alg; t = nothing,
throw(SciMLStructuresCompatibilityError())
end

if p === nothing || p isa SciMLBase.NullParameters
if p === nothing || p === DiffEqBase.NullParameters()
tunables, repack = p, identity
elseif isscimlstructure(p)
tunables, repack, _ = canonicalize(Tunable(), p)
Expand Down
2 changes: 1 addition & 1 deletion src/interpolating_adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ end
p = parameter_values(sol.prob)
u0 = state_values(sol.prob)

if p === nothing || p isa SciMLBase.NullParameters
if p === nothing || p === DiffEqBase.NullParameters()
tunables, repack = p, identity
else
tunables, repack, _ = canonicalize(Tunable(), p)
Expand Down
4 changes: 2 additions & 2 deletions src/sensitivity_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ function adjoint_sensitivities(sol, args...;
sensealg = InterpolatingAdjoint(),
verbose = true, kwargs...)
p = SymbolicIndexingInterface.parameter_values(sol)
if !(p === nothing || p isa SciMLBase.NullParameters)
if !(p === nothing || p === DiffEqBase.NullParameters())
if !isscimlstructure(p) && !isfunctor(p)
throw(SciMLStructuresCompatibilityError())
end
Expand Down Expand Up @@ -449,7 +449,7 @@ function _adjoint_sensitivities(sol, sensealg, alg;
save_everystep = false, save_start = false, saveat = eltype(state_values(sol, 1))[],
tstops = tstops, abstol = abstol, reltol = reltol, kwargs...)

if mtkp === nothing || mtkp isa SciMLBase.NullParameters
if mtkp === nothing || mtkp === DiffEqBase.NullParameters()
tunables, repack = mtkp, identity
else
tunables, _, _ = canonicalize(Tunable(), mtkp)
Expand Down

0 comments on commit a752804

Please sign in to comment.