Skip to content

Commit

Permalink
Updated adjoint_oop.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
Abhishek-1Bhatt committed Jul 14, 2022
1 parent b18f9d3 commit 174812a
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions test/adjoint_oop.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,15 +76,15 @@ du0, dp = adjoint_sensitivities(sol,Tsit5();dgdu_continuous=dg,g=G,

function G_p(p)
tmp_prob = remake(prob_nn,p=p)
sol = solve(tmp_prob,Tsit5(),abstol=1e-5,reltol=1e-5)
res,err = quadgk((t)-> (sum(sol_n(t) .- sol(t)).^2)./2,0.0,1.0,atol=1e-5,rtol=1e-5) # sol_n(t):numerical solution/data(above)
sol = solve(tmp_prob,Tsit5(),abstol=1e-12,reltol=1e-12)
res,err = quadgk((t)-> (sum(sol_n(t) .- sol(t)).^2)./2,0.0,1.0,atol=1e-12,rtol=1e-12) # sol_n(t):numerical solution/data(above)
res
end

function G_u(u0)
tmp_prob = remake(prob_nn,u0=u0)
sol = solve(tmp_prob,Tsit5(),abstol=1e-5,reltol=1e-5)
res,err = quadgk((t)-> (sum(sol_n(t) .- sol(t)).^2)./2,0.0,1.0,atol=1e-5,rtol=1e-5) # sol_n(t):numerical solution/data(above)
sol = solve(tmp_prob,Tsit5(),abstol=1e-12,reltol=1e-12)
res,err = quadgk((t)-> (sum(sol_n(t) .- sol(t)).^2)./2,0.0,1.0,atol=1e-12,rtol=1e-12) # sol_n(t):numerical solution/data(above)
res
end

Expand All @@ -97,7 +97,7 @@ n_dp = ForwardDiff.gradient(G_p,p_nn)
#concrete_solve

du0, dp = Zygote.gradient((u0, p) -> sum(concrete_solve(prob_nn, Tsit5(), u0, p,
abstol = 1e-5, reltol = 1e-5,
abstol = 1e-12, reltol = 1e-12,
saveat = tsteps,
sensealg = QuadratureAdjoint(autojacvec=ZygoteVJP()))),
u0, p_nn)
Expand Down

0 comments on commit 174812a

Please sign in to comment.