diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index 0fd6045d0..35ca3c13b 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -44,7 +44,7 @@ DifferentiationInterfaceTrackerExt = "Tracker" DifferentiationInterfaceZygoteExt = ["Zygote", "ForwardDiff"] [compat] -ADTypes = "1.2.0" +ADTypes = "1.5.0" ChainRulesCore = "1.23.0" Compat = "3,4" Diffractor = "=0.2.6" diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl index 89d091b51..5bf6f15ee 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl @@ -42,9 +42,11 @@ struct ReverseDiffGradientExtras{T} <: GradientExtras tape::T end -function DI.prepare_gradient(f, backend::AutoReverseDiff, x::AbstractArray) +function DI.prepare_gradient( + f, ::AutoReverseDiff{Compile}, x::AbstractArray +) where {Compile} tape = GradientTape(f, x) - if backend.compile + if Compile tape = compile(tape) end return ReverseDiffGradientExtras(tape) @@ -91,9 +93,11 @@ struct ReverseDiffOneArgJacobianExtras{T} <: JacobianExtras tape::T end -function DI.prepare_jacobian(f, backend::AutoReverseDiff, x::AbstractArray) +function DI.prepare_jacobian( + f, ::AutoReverseDiff{Compile}, x::AbstractArray +) where {Compile} tape = JacobianTape(f, x) - if backend.compile + if Compile tape = compile(tape) end return ReverseDiffOneArgJacobianExtras(tape) @@ -140,9 +144,9 @@ struct ReverseDiffHessianExtras{T} <: HessianExtras tape::T end -function DI.prepare_hessian(f, backend::AutoReverseDiff, x::AbstractArray) +function DI.prepare_hessian(f, ::AutoReverseDiff{Compile}, x::AbstractArray) where {Compile} tape = HessianTape(f, x) - if backend.compile + if Compile tape = compile(tape) end return ReverseDiffHessianExtras(tape) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/twoarg.jl index 5898fda99..b45a7e211 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/twoarg.jl @@ -72,10 +72,10 @@ struct ReverseDiffTwoArgJacobianExtras{T} <: JacobianExtras end function DI.prepare_jacobian( - f!, y::AbstractArray, backend::AutoReverseDiff, x::AbstractArray -) + f!, y::AbstractArray, ::AutoReverseDiff{Compile}, x::AbstractArray +) where {Compile} tape = JacobianTape(f!, y, x) - if backend.compile + if Compile tape = compile(tape) end return ReverseDiffTwoArgJacobianExtras(tape)