Skip to content

Commit

Permalink
bug fix, formatter
Browse files Browse the repository at this point in the history
  • Loading branch information
Abhishek-1Bhatt committed Jul 20, 2022
1 parent a3747b4 commit 15fadaf
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 24 deletions.
12 changes: 9 additions & 3 deletions src/quadrature_adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,13 @@ end
(dgdu_continuous === nothing && dgdp_continuous === nothing ||
g !== nothing))

λ = zero(u0)
if ArrayInterfaceCore.ismutable(u0)
len = length(u0)
λ = similar(u0, len)
λ .= false
else
λ = zero(u0)
end
sense = ODEQuadratureAdjointSensitivityFunction(g, sensealg, discrete, sol,
dgdu_continuous, dgdp_continuous)

Expand Down Expand Up @@ -334,10 +340,10 @@ function _adjoint_sensitivities(sol, sensealg::QuadratureAdjoint, alg; t = nothi
for i in (length(t) - 1):-1:1
if ArrayInterfaceCore.ismutable(res)
res .+= quadgk(integrand, t[i], t[i + 1],
atol = abstol, rtol = reltol)[1]
atol = abstol, rtol = reltol)[1]
else
res += quadgk(integrand, t[i], t[i + 1],
atol = abstol, rtol = reltol)[1]
atol = abstol, rtol = reltol)[1]
end
if t[i] == t[i + 1]
for cb in callback.discrete_callbacks
Expand Down
41 changes: 20 additions & 21 deletions test/adjoint_oop.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ datasize = 20
tsteps = range(tspan[1], tspan[2], length = datasize)

function f(u, p, t)
p*u
p * u
end

prob = ODEProblem(f, u0, tspan, p)
sol = solve(prob, Tsit5(), saveat=tsteps, abstol = 1e-12, reltol = 1e-12)
sol = solve(prob, Tsit5(), saveat = tsteps, abstol = 1e-12, reltol = 1e-12)

## Discrete Case
dg_disc(u, p, t, i; outtype = nothing) = u .- 1
Expand All @@ -27,15 +27,15 @@ du0, dp = adjoint_sensitivities(sol, Tsit5(); t = tsteps, dgdu_discrete = dg_dis
function G_p(p)
tmp_prob = remake(prob, p = p)
u = Array(solve(tmp_prob, Tsit5(), saveat = tsteps,
sensealg = SensitivityADPassThrough(), abstol=1e-12, reltol=1e-12))
sensealg = SensitivityADPassThrough(), abstol = 1e-12, reltol = 1e-12))

return sum(((1 .- u) .^ 2) ./ 2)
end

function G_u(u0)
tmp_prob = remake(prob, u0 = u0)
u = Array(solve(tmp_prob, Tsit5(), saveat = tsteps,
sensealg = SensitivityADPassThrough(), abstol=1e-12, reltol=1e-12))
sensealg = SensitivityADPassThrough(), abstol = 1e-12, reltol = 1e-12))
return sum(((1 .- u) .^ 2) ./ 2)
end

Expand All @@ -44,13 +44,13 @@ G_u(u0)
n_dp = ForwardDiff.gradient(G_p, p)
n_du0 = ForwardDiff.gradient(G_u, u0)

@test n_du0 du0 rtol = 1e-3
@test_broken n_dp dp' rtol = 1e-3
@test n_du0du0 rtol=1e-3
@test_broken n_dpdp' rtol=1e-3
@test sum(n_dp - dp') < 8.0

## Continuous Case

g(u, p, t) = sum((u.^2)./2)
g(u, p, t) = sum((u .^ 2) ./ 2)

function dg(u, p, t)
u
Expand All @@ -67,24 +67,24 @@ du0, dp = adjoint_sensitivities(sol, Tsit5(); dgdu_continuous = dg, g = g,
function G_p(p)
tmp_prob = remake(prob, p = p)
sol = solve(tmp_prob, Tsit5(), abstol = 1e-12, reltol = 1e-12)
res, err = quadgk((t) -> (sum((sol(t).^2)./2)), 0.0, 5.0, atol = 1e-12,
res, err = quadgk((t) -> (sum((sol(t) .^ 2) ./ 2)), 0.0, 5.0, atol = 1e-12,
rtol = 1e-12)
res
end

function G_u(u0)
tmp_prob = remake(prob, u0 = u0)
sol = solve(tmp_prob, Tsit5(), abstol = 1e-12, reltol = 1e-12)
res, err = quadgk((t) -> (sum((sol(t).^2)./2)), 0.0, 5.0, atol = 1e-12,
res, err = quadgk((t) -> (sum((sol(t) .^ 2) ./ 2)), 0.0, 5.0, atol = 1e-12,
rtol = 1e-12)
res
end

n_du0 = ForwardDiff.gradient(G_u, u0)
n_dp = ForwardDiff.gradient(G_p, p)

@test_broken n_du0 du0 rtol=1e-3
@test_broken n_dp dp' rtol=1e-3
@test_broken n_du0du0 rtol=1e-3
@test_broken n_dpdp' rtol=1e-3

@test sum(n_du0 - du0) < 1.0
@test sum(n_dp - dp) < 5.0
Expand All @@ -100,9 +100,6 @@ du0, dp = Zygote.gradient((u0, p) -> sum(concrete_solve(prob, Tsit5(), u0, p,
@test !iszero(du0)
@test !iszero(dp)




##Neural ODE adjoint with SimpleChains
u0 = @SArray Float32[2.0, 0.0]
datasize = 30
Expand Down Expand Up @@ -158,8 +155,8 @@ G_u(u0)
n_dp = ForwardDiff.gradient(G_p, p_nn)
n_du0 = ForwardDiff.gradient(G_u, u0)

@test n_du0 du0 rtol = 1e-3
@test n_dp dp' rtol = 1e-3
@test n_du0du0 rtol=1e-3
@test n_dpdp' rtol=1e-3

## Continuous case

Expand All @@ -179,24 +176,26 @@ du0, dp = adjoint_sensitivities(sol, Tsit5(); dgdu_continuous = dg, g = G,
function G_p(p)
tmp_prob = remake(prob_nn, p = p)
sol = solve(tmp_prob, Tsit5(), abstol = 1e-12, reltol = 1e-12)
res, err = quadgk((t) -> (sum(((sol_n(t) .- sol(t)).^2)./2)), 0.0, 1.5, atol = 1e-12,
res, err = quadgk((t) -> (sum(((sol_n(t) .- sol(t)) .^ 2) ./ 2)), 0.0, 1.5,
atol = 1e-12,
rtol = 1e-12) # sol_n(t):numerical solution/data(above)
res
end

function G_u(u0)
tmp_prob = remake(prob_nn, u0 = u0)
sol = solve(tmp_prob, Tsit5(), abstol = 1e-12, reltol = 1e-12)
res, err = quadgk((t) -> (sum(((sol_n(t) .- sol(t)).^2)./2)), 0.0, 1.5, atol = 1e-12,
res, err = quadgk((t) -> (sum(((sol_n(t) .- sol(t)) .^ 2) ./ 2)), 0.0, 1.5,
atol = 1e-12,
rtol = 1e-12) # sol_n(t):numerical solution/data(above)
res
end

n_du0 = ForwardDiff.gradient(G_u, u0)
n_dp = ForwardDiff.gradient(G_p, p_nn)

@test n_du0 du0 rtol=1e-3
@test n_dp dp' rtol=1e-3
@test n_du0du0 rtol=1e-3
@test n_dpdp' rtol=1e-3

#concrete_solve

Expand All @@ -207,4 +206,4 @@ du0, dp = Zygote.gradient((u0, p) -> sum(concrete_solve(prob_nn, Tsit5(), u0, p,
u0, p_nn)

@test !iszero(du0)
@test !iszero(dp)
@test !iszero(dp)

0 comments on commit 15fadaf

Please sign in to comment.