Skip to content

Commit

Permalink
oop dispatch for dgdu function, some tests, some corrections
Browse files Browse the repository at this point in the history
  • Loading branch information
Abhishek-1Bhatt committed Jul 11, 2022
1 parent bdde708 commit 2c067ab
Show file tree
Hide file tree
Showing 5 changed files with 144 additions and 11 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,11 @@ Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["AlgebraicMultigrid", "ComponentArrays", "Calculus", "Distributed", "DelayDiffEq", "Optimization", "OptimizationFlux", "OptimizationOptimJL", "Flux", "ReverseDiff", "SafeTestsets", "Test", "Random", "Pkg", "SteadyStateDiffEq", "NLsolve", "SparseArrays"]
test = ["AlgebraicMultigrid", "ComponentArrays", "Calculus", "Distributed", "DelayDiffEq", "Optimization", "OptimizationFlux", "OptimizationOptimJL", "Flux", "ReverseDiff", "SafeTestsets", "SimpleChains", "StaticArrays", "Test", "Random", "Pkg", "SteadyStateDiffEq", "NLsolve", "SparseArrays"]
4 changes: 2 additions & 2 deletions src/adjoint_common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -407,8 +407,8 @@ function (f::ReverseLossCallback)(integrator)
g(gᵤ, y, p, t[cur_time[]], cur_time[])
else
@assert sensealg isa QuadratureAdjoint
gᵤ = isq ? λ : @view(λ[1:idx])
gᵤ = g(gᵤ, y, p, t[cur_time[]], cur_time[])
outtype = DiffEqBase.parameterless_type)
gᵤ = g(y, p, t[cur_time[]], cur_time[];outtype=outtype)
end

if issemiexplicitdae
Expand Down
60 changes: 59 additions & 1 deletion src/concrete_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,64 @@ function DiffEqBase._concrete_solve_adjoint(prob, alg,
end
end

# Out-of-place discrete adjoint seed ("dg/du") for save-point index `i`.
#
# Extracts the pullback cotangent for the state at index `i` from the
# closed-over seed `Δ`, restricted to the saved state indices `_save_idxs`,
# and converted to the array type `outtype` (via `adapt`) when one is given.
#
# NOTE(review): `only_end`, `Δ` and `_save_idxs` are closure variables from the
# enclosing `_concrete_solve_adjoint`; this function is not callable standalone.
function df(u, p, t, i; outtype = nothing)
    if only_end
        # Sensitivities requested only at the final time point.
        eltype(Δ) <: NoTangent && return
        if typeof(Δ) <: AbstractArray{<:AbstractArray} && length(Δ) == 1 && i == 1
            # user did sol[end] on only_end
            if typeof(_save_idxs) <: Number
                x = vec(Δ[1])
                _out = adapt(outtype, @view(x[_save_idxs]))
            elseif _save_idxs isa Colon
                _out = adapt(outtype, vec(Δ[1]))
            else
                _out = adapt(outtype,
                             vec(Δ[1])[_save_idxs])
            end
        else
            Δ isa NoTangent && return
            if typeof(_save_idxs) <: Number
                x = vec(Δ)
                _out = adapt(outtype, @view(x[_save_idxs]))
            elseif _save_idxs isa Colon
                _out = adapt(outtype, vec(Δ))
            else
                x = vec(Δ)
                _out = adapt(outtype, @view(x[_save_idxs]))
            end
        end
    else
        # Skip indices whose cotangent carries no information.
        !Base.isconcretetype(eltype(Δ)) &&
            (Δ[i] isa NoTangent || eltype(Δ) <: NoTangent) && return
        if typeof(Δ) <: AbstractArray{<:AbstractArray} || typeof(Δ) <: DESolution
            # Δ is a per-timepoint collection (or a solution object).
            x = Δ[i]
            if typeof(_save_idxs) <: Number
                _out = @view(x[_save_idxs])
            elseif _save_idxs isa Colon
                _out = vec(x)
            else
                _out = vec(@view(x[_save_idxs]))
            end
        else
            # Δ is one dense array with time along its last dimension.
            if typeof(_save_idxs) <: Number
                _out = adapt(outtype,
                             reshape(Δ, prod(size(Δ)[1:(end - 1)]),
                                     size(Δ)[end])[_save_idxs, i])
            elseif _save_idxs isa Colon
                _out = vec(adapt(outtype,
                                 reshape(Δ, prod(size(Δ)[1:(end - 1)]),
                                         size(Δ)[end])[:, i]))
            else
                # FIX(review): this branch previously indexed `[:, i]`, silently
                # ignoring a non-Colon `_save_idxs`; restrict to the saved
                # indices to match the Number branch and the collection path.
                _out = vec(adapt(outtype,
                                 reshape(Δ,
                                         prod(size(Δ)[1:(end - 1)]),
                                         size(Δ)[end])[_save_idxs, i]))
            end
        end
    end
    return _out
end

if haskey(kwargs_adj, :callback_adj)
cb2 = CallbackSet(cb, kwargs[:callback_adj])
else
Expand Down Expand Up @@ -855,7 +913,7 @@ function DiffEqBase._concrete_solve_adjoint(prob, alg, sensealg::ReverseDiffAdjo

function reversediff_adjoint_forwardpass(_u0, _p)
if (convert_tspan(sensealg) === nothing &&
((haskey(kwargs, :callback) && has_continuous_callback(kwargs[:callback])))) ||
((haskey(kwargs, :callback) && has_a_callback(kwargs[:callback])))) ||
(convert_tspan(sensealg) !== nothing && convert_tspan(sensealg))
_tspan = convert.(eltype(_p), prob.tspan)
else
Expand Down
12 changes: 6 additions & 6 deletions src/derivative_wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -916,14 +916,14 @@ function accumulate_cost(dλ, y, p, t, S::TS,
@unpack dg, dg_val, g, g_grad_config = S.diffcache
if dg !== nothing
if !(dg isa Tuple)
dg(dg_val, y, p, t)
-= vec(dg_val)
dg_val = dg(y, p, t)
-= dg_val
else
dg[1](dg_val[1], y, p, t)
-= vec(dg_val[1])
dg[1](y, p, t)
-= dg_val
if dgrad !== nothing
dg[2](dg_val[2], y, p, t)
dgrad .-= vec(dg_val[2])
dg[2](y, p, t)
dgrad -= dg_val
end
end
else
Expand Down
75 changes: 74 additions & 1 deletion test/adjoint.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using SciMLSensitivity, OrdinaryDiffEq, RecursiveArrayTools, DiffEqBase,
ForwardDiff, Calculus, QuadGK, LinearAlgebra, Zygote
ForwardDiff, Calculus, QuadGK, LinearAlgebra, Zygote, SimpleChains, StaticArrays, Optimization, OptimizationFlux
using Test

function fb(du, u, p, t)
Expand Down Expand Up @@ -849,3 +849,76 @@ using LinearAlgebra, SciMLSensitivity, OrdinaryDiffEq, ForwardDiff, QuadGK
end
end
end

#### Fully out-of-place (oop) adjoint test: static arrays + SimpleChains

u0 = @SArray Float32[2.0, 0.0]   # initial state as an immutable static array
datasize = 30                    # number of saved time points
tspan = (0.0f0, 1.5f0)
tsteps = range(tspan[1], tspan[2], length = datasize)

# Ground-truth dynamics: cubic state pushed through a fixed linear map.
# Returns a static vector, keeping the whole pipeline out-of-place.
function trueODE(u, p, t)
    A = @SMatrix Float32[-0.1 2.0; -2.0 -0.1]
    return transpose(A) * (u .^ 3)
end

# Generate training data from the ground-truth model.
prob = ODEProblem(trueODE, u0, tspan)
data = Array(solve(prob, Tsit5(), saveat = tsteps))

# Small SimpleChains MLP used as the learned right-hand side.
sc = SimpleChain(
    static(2),                           # input dimension
    Activation(x -> x.^3),               # cubic feature map (mirrors trueODE)
    TurboDense{true}(tanh, static(50)),
    TurboDense{true}(identity, static(2))
)

# Initial network parameters (flat vector, optimized below).
p_nn = SimpleChains.init_params(sc)

df(u,p,t) = sc(u,p)

prob_nn = ODEProblem(df, u0, tspan, p_nn)
sol = solve(prob_nn, Tsit5();saveat=tsteps)
# Discrete (per-save-point) loss gradient in oop form; `i` is the time index.
# NOTE(review): sign convention (data - u) should be confirmed against the
# `adjoint_sensitivities` docs for `dg_discrete`.
dg_disc(u, p, t, i;outtype=nothing) = data[:, i] .- u

# Discrete adjoint at the final time point through the oop dispatch.
res = adjoint_sensitivities(sol,Tsit5();t=tsteps[end],dg_discrete=dg_disc,
                            sensealg=QuadratureAdjoint(autojacvec=ZygoteVJP()))

@test !iszero(res[1])
@test !iszero(res[2])

G(u,p,t) = sum(abs2, ((data.-u)./2))

# Continuous cost gradient dg/du paired with `G`, out-of-place form.
# FIX(review): removed the leftover `@show u` debug print that spammed
# test output on every integrator step.
function dg(u, p, t)
    return data[:, end] .- u
end

# Continuous adjoint with cost G and its gradient dg through the oop dispatch.
res = adjoint_sensitivities(sol,Tsit5();dg_continuous=dg,g=G,
                            sensealg=QuadratureAdjoint(autojacvec=ZygoteVJP()))

@test !iszero(res[1])
@test !iszero(res[2])

prob_nn = ODEProblem(f, u0, tspan)

# Integrate the neural ODE at parameters `p` and collect the trajectory as a
# plain Array (one saved state per column).
function predict_neuralode(p)
    nn_sol = solve(prob_nn, Tsit5(); p = p, saveat = tsteps,
                   sensealg = QuadratureAdjoint(autojacvec = ZygoteVJP()))
    return Array(nn_sol)
end

# Sum-of-squares loss against the training data; the prediction is returned
# alongside the loss for use by the optimization callback.
function loss_neuralode(p)
    prediction = predict_neuralode(p)
    return sum(abs2, data .- prediction), prediction
end

# Optimization callback: display the current loss each iteration; returning
# `false` keeps the optimizer running. `doplot` is currently unused.
callback = function (p, l, pred; doplot = true)
    display(l)
    return false
end

# Train the network with ADAM; AutoZygote drives the adjoint machinery above.
optf = Optimization.OptimizationFunction((x,p)->loss_neuralode(x), Optimization.AutoZygote())
optprob = Optimization.OptimizationProblem(optf, p_nn)

res = Optimization.solve(optprob, ADAM(0.05),callback=callback,maxiters=300)

# FIX(review): `loss_neuralode` returns a (loss, prediction) tuple, and
# `tuple < 0.8` throws a MethodError — compare only the loss component.
@test loss_neuralode(res.u)[1] < 0.8

0 comments on commit 2c067ab

Please sign in to comment.