Skip to content

Commit

Permalink
chore: cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
DhairyaLGandhi committed Jun 24, 2024
1 parent 4b987e0 commit 18bcfa9
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
13 changes: 10 additions & 3 deletions src/concrete_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -243,23 +243,27 @@ function DiffEqBase._concrete_solve_adjoint(prob::Union{SciMLBase.AbstractODEPro
end

if !(p === nothing || p isa SciMLBase.NullParameters)
if !isscimlstructure(p)
if !isscimlstructure(p) && !isfunctor(p)
throw(SciMLStructuresCompatibilityError())
end
end

if p === nothing || p isa SciMLBase.NullParameters
tunables, repack = p, identity
else
elseif isscimlstructure(p)
tunables, repack, aliases = canonicalize(Tunable(), p)
elseif isfunctor(p)
tunables, repack = Functors.functor(p)
else
throw(SciMLStructuresCompatibilityError())
end

default_sensealg = automatic_sensealg_choice(prob, u0, tunables, verbose, repack)
if has_cb && default_sensealg isa AbstractAdjointSensitivityAlgorithm
default_sensealg = setvjp(default_sensealg, ReverseDiffVJP())
end
DiffEqBase._concrete_solve_adjoint(prob, alg, default_sensealg, u0, p,
originator::SciMLBase.ADOriginator, args...; verbose,
originator, args...; verbose,
kwargs...)
end

Expand Down Expand Up @@ -1203,6 +1207,9 @@ function DiffEqBase._concrete_solve_adjoint(prob::Union{SciMLBase.AbstractDiscre
end
elseif prob isa
Union{SciMLBase.AbstractODEProblem, SciMLBase.AbstractSDEProblem}
@show typeof(_p)
@show typeof(repack(_p))
@show SciMLStructures.replace(Tunable(), p, _p) |> typeof
_f = function (u, p, t)
out = prob.f(u, p, t)
if out isa TrackedArray
Expand Down
3 changes: 1 addition & 2 deletions src/forward_sensitivity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -644,8 +644,7 @@ function SciMLBase.remake(
{uType, tType, isinplace, P, F, K}
_p = p === nothing ? parameter_values(prob) : p
_f = f === nothing ? prob.f.f : f
initial_conditions = state_values(prob)
_u0 = u0 === nothing ? initial_conditions[1:(prob.f.numindvar)] : u0[1:(prob.f.numindvar)]
_u0 = u0 === nothing ? state_values(prob, 1:(prob.f.numindvar)) : u0[1:(prob.f.numindvar)]
_tspan = tspan === nothing ? prob.tspan : tspan
ODEForwardSensitivityProblem(_f, _u0,
_tspan, _p, prob.problem_type.sensealg;
Expand Down

0 comments on commit 18bcfa9

Please sign in to comment.