Skip to content

Commit

Permalink
OOP Adjoint on Numerical solve
Browse files Browse the repository at this point in the history
  • Loading branch information
ba2tro committed Jul 20, 2022
1 parent 9140817 commit a3747b4
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 33 deletions.
7 changes: 6 additions & 1 deletion src/quadrature_adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -332,8 +332,13 @@ function _adjoint_sensitivities(sol, sensealg::QuadratureAdjoint, alg; t = nothi
end

for i in (length(t) - 1):-1:1
res .+= quadgk(integrand, t[i], t[i + 1],
if ArrayInterfaceCore.ismutable(res)
res .+= quadgk(integrand, t[i], t[i + 1],
atol = abstol, rtol = reltol)[1]
else
res += quadgk(integrand, t[i], t[i + 1],
atol = abstol, rtol = reltol)[1]
end
if t[i] == t[i + 1]
for cb in callback.discrete_callbacks
if t[i] cb.affect!.event_times
Expand Down
133 changes: 101 additions & 32 deletions test/adjoint_oop.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,108 @@ using SciMLSensitivity, OrdinaryDiffEq, SimpleChains, StaticArrays, QuadGK, Forw
Zygote
using Test

##Adjoints of numerical solve

u0 = @SVector [1.0f0, 1.0f0]
p = @SMatrix [1.5f0 -1.0f0; 3.0f0 -1.0f0]
tspan = [0.0f0, 5.0f0]
datasize = 20
tsteps = range(tspan[1], tspan[2], length = datasize)

function f(u, p, t)
p*u
end

prob = ODEProblem(f, u0, tspan, p)
sol = solve(prob, Tsit5(), saveat=tsteps, abstol = 1e-12, reltol = 1e-12)

## Discrete Case
dg_disc(u, p, t, i; outtype = nothing) = u .- 1

du0, dp = adjoint_sensitivities(sol, Tsit5(); t = tsteps, dgdu_discrete = dg_disc,
sensealg = QuadratureAdjoint(autojacvec = ZygoteVJP()))

## with ForwardDiff
function G_p(p)
tmp_prob = remake(prob, p = p)
u = Array(solve(tmp_prob, Tsit5(), saveat = tsteps,
sensealg = SensitivityADPassThrough(), abstol=1e-12, reltol=1e-12))

return sum(((1 .- u) .^ 2) ./ 2)
end

function G_u(u0)
tmp_prob = remake(prob, u0 = u0)
u = Array(solve(tmp_prob, Tsit5(), saveat = tsteps,
sensealg = SensitivityADPassThrough(), abstol=1e-12, reltol=1e-12))
return sum(((1 .- u) .^ 2) ./ 2)
end

G_p(p)
G_u(u0)
n_dp = ForwardDiff.gradient(G_p, p)
n_du0 = ForwardDiff.gradient(G_u, u0)

@test n_du0 du0 rtol = 1e-3
@test_broken n_dp dp' rtol = 1e-3
@test sum(n_dp - dp') < 8.0

## Continuous Case

g(u, p, t) = sum((u.^2)./2)

function dg(u, p, t)
u
end

du0, dp = adjoint_sensitivities(sol, Tsit5(); dgdu_continuous = dg, g = g,
sensealg = QuadratureAdjoint(autojacvec = ZygoteVJP()))

@test !iszero(du0)
@test !iszero(dp)

##numerical

function G_p(p)
tmp_prob = remake(prob, p = p)
sol = solve(tmp_prob, Tsit5(), abstol = 1e-12, reltol = 1e-12)
res, err = quadgk((t) -> (sum((sol(t).^2)./2)), 0.0, 5.0, atol = 1e-12,
rtol = 1e-12)
res
end

function G_u(u0)
tmp_prob = remake(prob, u0 = u0)
sol = solve(tmp_prob, Tsit5(), abstol = 1e-12, reltol = 1e-12)
res, err = quadgk((t) -> (sum((sol(t).^2)./2)), 0.0, 5.0, atol = 1e-12,
rtol = 1e-12)
res
end

n_du0 = ForwardDiff.gradient(G_u, u0)
n_dp = ForwardDiff.gradient(G_p, p)

@test_broken n_du0 du0 rtol=1e-3
@test_broken n_dp dp' rtol=1e-3

@test sum(n_du0 - du0) < 1.0
@test sum(n_dp - dp) < 5.0

## concrete solve

du0, dp = Zygote.gradient((u0, p) -> sum(concrete_solve(prob, Tsit5(), u0, p,
abstol = 1e-6, reltol = 1e-6,
saveat = tsteps,
sensealg = QuadratureAdjoint(autojacvec = ZygoteVJP()))),
u0, p)

@test !iszero(du0)
@test !iszero(dp)


#####


##Neural ODE adjoint with SimpleChains
u0 = @SArray Float32[2.0, 0.0]
datasize = 30
tspan = (0.0f0, 1.5f0)
Expand Down Expand Up @@ -108,34 +207,4 @@ du0, dp = Zygote.gradient((u0, p) -> sum(concrete_solve(prob_nn, Tsit5(), u0, p,
u0, p_nn)

@test !iszero(du0)
@test !iszero(dp)

#####Delete################################################################
using Flux

u0 = Float32[2.0; 0.0]
datasize = 30
tspan = (0.0f0, 1.5f0)
tsteps = range(tspan[1], tspan[2], length = datasize)

function trueODEfunc(u, p, t)
true_A = [-0.1f0 2.0f0; -2.0f0 -0.1f0]
return ((u.^3)'true_A)'
end

prob_trueode = ODEProblem(trueODEfunc, u0, tspan)
ode_data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps))

dudt2 = Chain((x) -> x.^3,
Dense(2, 50, tanh),
Dense(50, 2))
p, re = Flux.destructure(dudt2)
f(u, p, t) = re(p)(u)

prob_nn = ODEProblem(f, u0, tspan)

du0, dp = Zygote.gradient((u0, p) -> sum(concrete_solve(prob_nn, Tsit5(), u0, p,
abstol = 1e-12, reltol = 1e-12,
saveat = tsteps,
sensealg = QuadratureAdjoint(autojacvec = ZygoteVJP()))),
u0, p)
@test !iszero(dp)

0 comments on commit a3747b4

Please sign in to comment.