Skip to content

Commit

Permalink
Use new ReverseDiff compile type parameter (#351)
Browse files Browse the repository at this point in the history
  • Loading branch information
niklasschmitz committed Jul 13, 2024
1 parent e79bc20 commit f27415c
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 10 deletions.
2 changes: 1 addition & 1 deletion DifferentiationInterface/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ DifferentiationInterfaceTrackerExt = "Tracker"
DifferentiationInterfaceZygoteExt = ["Zygote", "ForwardDiff"]

[compat]
ADTypes = "1.2.0"
ADTypes = "1.5.0"
ChainRulesCore = "1.23.0"
Compat = "3,4"
Diffractor = "=0.2.6"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,11 @@ struct ReverseDiffGradientExtras{T} <: GradientExtras
tape::T
end

function DI.prepare_gradient(f, backend::AutoReverseDiff, x::AbstractArray)
function DI.prepare_gradient(
f, ::AutoReverseDiff{Compile}, x::AbstractArray
) where {Compile}
tape = GradientTape(f, x)
if backend.compile
if Compile
tape = compile(tape)
end
return ReverseDiffGradientExtras(tape)
Expand Down Expand Up @@ -91,9 +93,11 @@ struct ReverseDiffOneArgJacobianExtras{T} <: JacobianExtras
tape::T
end

function DI.prepare_jacobian(f, backend::AutoReverseDiff, x::AbstractArray)
function DI.prepare_jacobian(
f, ::AutoReverseDiff{Compile}, x::AbstractArray
) where {Compile}
tape = JacobianTape(f, x)
if backend.compile
if Compile
tape = compile(tape)
end
return ReverseDiffOneArgJacobianExtras(tape)
Expand Down Expand Up @@ -140,9 +144,9 @@ struct ReverseDiffHessianExtras{T} <: HessianExtras
tape::T
end

function DI.prepare_hessian(f, backend::AutoReverseDiff, x::AbstractArray)
function DI.prepare_hessian(f, ::AutoReverseDiff{Compile}, x::AbstractArray) where {Compile}
tape = HessianTape(f, x)
if backend.compile
if Compile
tape = compile(tape)
end
return ReverseDiffHessianExtras(tape)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,10 @@ struct ReverseDiffTwoArgJacobianExtras{T} <: JacobianExtras
end

function DI.prepare_jacobian(
f!, y::AbstractArray, backend::AutoReverseDiff, x::AbstractArray
)
f!, y::AbstractArray, ::AutoReverseDiff{Compile}, x::AbstractArray
) where {Compile}
tape = JacobianTape(f!, y, x)
if backend.compile
if Compile
tape = compile(tape)
end
return ReverseDiffTwoArgJacobianExtras(tape)
Expand Down

0 comments on commit f27415c

Please sign in to comment.