Skip to content

Commit

Permalink
handle shadowing cases
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Aug 25, 2022
1 parent 462402d commit 00c254b
Show file tree
Hide file tree
Showing 6 changed files with 15 additions and 12 deletions.
7 changes: 4 additions & 3 deletions src/backsolve_adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -272,14 +272,15 @@ end
else
transformed_function = StochasticTransformedFunction(sol, sol.prob.f, sol.prob.g,
corfunc_analytical)
drift_function = ODEFunction{false,true}(transformed_function)
drift_function = ODEFunction{false, true}(transformed_function)
sense_drift = ODEBacksolveSensitivityFunction(g, sensealg, discrete, sol,
dgdu_continuous, dgdp_continuous,
drift_function, alg)
end

diffusion_function = ODEFunction{isinplace(sol.prob),true}(sol.prob.g, jac = diffusion_jac,
paramjac = diffusion_paramjac)
diffusion_function = ODEFunction{isinplace(sol.prob), true}(sol.prob.g,
jac = diffusion_jac,
paramjac = diffusion_paramjac)
sense_diffusion = ODEBacksolveSensitivityFunction(g, sensealg, discrete, sol,
dgdu_continuous, dgdp_continuous,
diffusion_function, alg;
Expand Down
2 changes: 1 addition & 1 deletion src/concrete_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1193,7 +1193,7 @@ function DiffEqBase._concrete_solve_adjoint(prob::SciMLBase.AbstractODEProblem,
if haskey(kwargs, :callback)
error("Sensitivity analysis based on Least Squares Shadowing is not compatible with callbacks. Please select another `sensealg`.")
else
_prob = remake(prob, u0 = u0, p = p)
_prob = remake(prob, f = unwrapped_f(prob.f), u0 = u0, p = p)
end

sol = solve(_prob, alg, args...; save_start = save_start, save_end = save_end,
Expand Down
5 changes: 3 additions & 2 deletions src/interpolating_adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -431,8 +431,9 @@ end
abstol = abstol),
tspan = tspan)

diffusion_function = ODEFunction{isinplace(sol.prob),true}(sol.prob.g, jac = diffusion_jac,
paramjac = diffusion_paramjac)
diffusion_function = ODEFunction{isinplace(sol.prob), true}(sol.prob.g,
jac = diffusion_jac,
paramjac = diffusion_paramjac)
sense_diffusion = ODEInterpolatingAdjointSensitivityFunction(g, sensealg, discrete, sol,
dgdu_continuous,
dgdp_continuous,
Expand Down
6 changes: 3 additions & 3 deletions src/lss.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ function LSSSensitivityFunction(sensealg, f, analytic, jac, jac_prototype, spars
paramjac, u0,
alg, p, f_cache, mm,
colorvec, tspan, g, dgdu, dgdp)
uf = DiffEqBase.UJacobianWrapper(f, tspan[1], p)
pf = DiffEqBase.ParamJacobianWrapper(f, tspan[1], copy(u0))
uf = DiffEqBase.UJacobianWrapper(unwrapped_f(f), tspan[1], p)
pf = DiffEqBase.ParamJacobianWrapper(unwrapped_f(f), tspan[1], copy(u0))

if DiffEqBase.has_jac(f)
jac_config = nothing
Expand Down Expand Up @@ -135,7 +135,7 @@ function ForwardLSSProblem(sol, sensealg::ForwardLSS;
dgdp_continuous = nothing,
g = sensealg.g,
kwargs...)
@unpack f, p, u0, tspan = sol.prob
@unpack p, u0, tspan = sol.prob

isinplace = DiffEqBase.isinplace(f)

Expand Down
3 changes: 2 additions & 1 deletion src/nilsas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,10 @@ function NILSASProblem(sol, sensealg::NILSAS, alg;
t = nothing, dgdu_discrete = nothing, dgdp_discrete = nothing,
dgdu_continuous = nothing, dgdp_continuous = nothing, g = sensealg.g,
kwargs...)
@unpack f, p, u0, tspan = sol.prob
@unpack p, u0, tspan = sol.prob
@unpack nseg, nstep, rng, adjoint_sensealg, M = sensealg #number of segments on time interval, number of steps saved on each segment

f = unwrapped_f(sol.prob.f)
numindvar = length(u0)
numparams = length(p)

Expand Down
4 changes: 2 additions & 2 deletions src/nilss.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ function NILSSSensitivityFunction(sensealg, f, u0, p, tspan, g, dgdu, dgdp,
dg_val .= false
end
else
pgpu = UGradientWrapper(g, tspan[1], p) # ∂g∂u
pgpp = ParamGradientWrapper(g, tspan[1], u0) #∂g∂p
pgpu = UGradientWrapper(unwrapped_f(g), tspan[1], p) # ∂g∂u
pgpp = ParamGradientWrapper(unwrapped_f(g), tspan[1], u0) #∂g∂p
pgpu_config = build_grad_config(sensealg, pgpu, u0, tspan[1])
pgpp_config = build_grad_config(sensealg, pgpp, u0, tspan[1])
dg_val = (similar(u0, numindvar), similar(u0, numparams))
Expand Down

0 comments on commit 00c254b

Please sign in to comment.