Skip to content

Commit

Permalink
chore: cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
DhairyaLGandhi committed Jun 23, 2024
1 parent 5463cec commit 006e2fb
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 9 deletions.
21 changes: 13 additions & 8 deletions src/concrete_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ function inplace_vjp(prob, u0, p, verbose, repack)
Enzyme.Duplicated(copy(u0), zero(u0)), Enzyme.Duplicated(copy(p), zero(p)), Enzyme.Const(prob.tspan[1]))
true
catch e
if verbose || have_not_warned_vjp[]
if verbose && have_not_warned_vjp[]
@warn "Potential performance improvement omitted. EnzymeVJP tried and failed in the automated AD choice algorithm. To show the stack trace, set SciMLSensitivity.STACKTRACE_WITH_VJPWARN[] = true. To turn off this printing, add `verbose = false` to the `solve` call.\n"
STACKTRACE_WITH_VJPWARN[] && showerror(stderr, e)
println()
Expand Down Expand Up @@ -244,7 +244,7 @@ function DiffEqBase._concrete_solve_adjoint(prob::Union{SciMLBase.AbstractODEPro

if !(p === nothing || p isa SciMLBase.NullParameters)
if !isscimlstructure(p)
throw(SciMLStructuresCompatiblityError())
throw(SciMLStructuresCompatibilityError())
end
end

Expand Down Expand Up @@ -273,7 +273,7 @@ function DiffEqBase._concrete_solve_adjoint(prob::Union{

if !(p === nothing || p isa SciMLBase.NullParameters)
if !isscimlstructure(p) && !isfunctor(p)
throw(SciMLStructuresCompatiblityError())
throw(SciMLStructuresCompatibilityError())
end
end

Expand Down Expand Up @@ -366,8 +366,10 @@ function DiffEqBase._concrete_solve_adjoint(prob::Union{SciMLBase.AbstractODEPro
tunables, repack = p, identity
elseif isscimlstructure(p)
tunables, repack, aliases = canonicalize(Tunable(), p)
else
elseif sensealg isa Union{QuadratureAdjoint, GaussAdjoint}
tunables, repack = Functors.functor(p)
else
throw(SciMLStructuresCompatibilityError())
end
# Remove saveat, etc. from kwargs since it's handled separately
# and letting it jump back in there can break the adjoint
Expand Down Expand Up @@ -1130,7 +1132,7 @@ function DiffEqBase._concrete_solve_adjoint(prob::Union{SciMLBase.AbstractDiscre

if !(p === nothing || p isa SciMLBase.NullParameters)
if !isscimlstructure(p)
throw(SciMLStructuresCompatiblityError())
throw(SciMLStructuresCompatibilityError())
end
end

Expand Down Expand Up @@ -1263,7 +1265,10 @@ function DiffEqBase._concrete_solve_adjoint(prob::Union{SciMLBase.AbstractDiscre
elseif ybar[1] isa Array
return Array(ybar)
else
tmp = reduce(hcat, vec.(ybar.u))
tmp = vec(ybar.u[1])
for i in 2:length(ybar.u)
tmp = hcat(tmp, vec(ybar.u[i]))
end
return reshape(tmp, size(ybar.u[1])..., length(ybar.u))
end
u0bar, pbar = pullback(tmp)
Expand Down Expand Up @@ -1317,10 +1322,10 @@ const SCIMLSTRUCTURES_ERROR_MESSAGE = """
In particular, adjoint sensitivities only applies to `Tunable`.
"""

struct SciMLStructuresCompatiblityError <: Exception
struct SciMLStructuresCompatibilityError <: Exception
end

function Base.showerror(io::IO, e::SciMLStructuresCompatiblityError)
function Base.showerror(io::IO, e::SciMLStructuresCompatibilityError)
println(io, SCIMLSTRUCTURES_ERROR_MESSAGE)
end

Expand Down
4 changes: 3 additions & 1 deletion src/gauss_adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -550,8 +550,10 @@ function _adjoint_sensitivities(sol, sensealg::GaussAdjoint, alg; t = nothing,
tunables, repack = p, identity
elseif isscimlstructure(p)
tunables, repack, _ = canonicalize(Tunable(), p)
else
elseif isfunctor(p)
tunables, repack = Functors.functor(p)
else
throw(SciMLStructuresCompatiblityError())

Check warning on line 556 in src/gauss_adjoint.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"Compatiblity" should be "Compatibility".
end
integrand = GaussIntegrand(sol, sensealg, checkpoints, dgdp_continuous)
integrand_values = IntegrandValuesSum(allocate_zeros(tunables))
Expand Down

0 comments on commit 006e2fb

Please sign in to comment.