Skip to content

Commit

Permalink
StaticArrays AD, tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Abhishek-1Bhatt committed Aug 12, 2022
1 parent 8c59e9a commit d354ba7
Show file tree
Hide file tree
Showing 6 changed files with 159 additions and 150 deletions.
3 changes: 1 addition & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
Expand Down Expand Up @@ -88,9 +89,7 @@ Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

Expand Down
4 changes: 3 additions & 1 deletion src/SciMLSensitivity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import ZygoteRules, Zygote, ReverseDiff
import ArrayInterfaceCore, ArrayInterfaceTracker
import Enzyme
import GPUArraysCore
using StaticArrays

import PreallocationTools: dualcache, get_tmp, DiffCache

Expand All @@ -24,7 +25,7 @@ using EllipsisNotation
using Markdown

using Reexport
import ChainRulesCore: unthunk, @thunk, NoTangent, @not_implemented
import ChainRulesCore: unthunk, @thunk, NoTangent, @not_implemented, Tangent, ProjectTo, project_type, _eltype_projectto, rrule
abstract type SensitivityFunction end
abstract type TransformedFunction end

Expand All @@ -45,6 +46,7 @@ include("concrete_solve.jl")
include("second_order.jl")
include("steadystate_adjoint.jl")
include("sde_tools.jl")
include("staticarrays.jl")

# AD Extensions
include("reversediff.jl")
Expand Down
3 changes: 1 addition & 2 deletions src/quadrature_adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -228,8 +228,7 @@ end
function (S::AdjointSensitivityIntegrand)(out, t)
@unpack y, λ, pJ, pf, p, f_cache, dgdp_cache, paramjac_config, sensealg, sol, adj_sol = S
f = sol.prob.f
# if eltype(sol.u) <: StaticArrays.SArray
if ArrayInterfaceCore.ismutable(eltype(sol.u))
if ArrayInterfaceCore.ismutable(y)
sol(y, t)
adj_sol(λ, t)
else
Expand Down
23 changes: 23 additions & 0 deletions src/staticarrays.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
### Projecting a tuple to SMatrix leads to ChainRulesCore._projection_mismatch by default, so overloaded here
function (project::ProjectTo{<:Tangent{<:Tuple}})(dx::StaticArrays.SArray)
dy = reshape(dx, axes(project.elements)) # allows for dx::OffsetArray
dz = ntuple(i -> project.elements[i](dy[i]), length(project.elements))
return project_type(project)(dz...)
end

### Project SArray to SArray
function ProjectTo(x::StaticArrays.SArray{S,T}) where {S, T}
return ProjectTo{StaticArrays.SArray}(; element=_eltype_projectto(T), axes=S)
end

function (project::ProjectTo{StaticArrays.SArray})(dx::AbstractArray{S,M}) where {S,M}
return StaticArrays.SArray{project.axes}(dx)
end

### Adjoint for SArray constructor

function rrule(::Type{T}, x::Tuple) where {T<:StaticArrays.SArray}
project_x = ProjectTo(x)
Array_pullback(ȳ) = (NoTangent(), project_x(ȳ))
return T(x), Array_pullback
end
Loading

0 comments on commit d354ba7

Please sign in to comment.