Skip to content

Commit

Permalink
Rename p to dg for Hessian-vector product (#318)
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle committed Jun 17, 2024
1 parent b41db2f commit fa551e5
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 24 deletions.
30 changes: 15 additions & 15 deletions DifferentiationInterface/src/second_order/hvp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,16 @@ Create an `extras_same` object that can be given to [`hvp`](@ref) and its varian
function prepare_hvp_same_point end

"""
hvp(f, backend, x, dx, [extras]) -> p
hvp(f, backend, x, dx, [extras]) -> dg
Compute the Hessian-vector product of `f` at point `x` with seed `dx`.
"""
function hvp end

"""
hvp!(f, p, backend, x, dx, [extras]) -> p
hvp!(f, dg, backend, x, dx, [extras]) -> dg
Compute the Hessian-vector product of `f` at point `x` with seed `dx`, overwriting `p`.
Compute the Hessian-vector product of `f` at point `x` with seed `dx`, overwriting `dg`.
"""
function hvp! end

Expand Down Expand Up @@ -141,8 +141,8 @@ function hvp(f::F, backend::AbstractADType, x, dx) where {F}
return hvp(f, backend, x, dx, prepare_hvp(f, backend, x, dx))
end

function hvp!(f::F, p, backend::AbstractADType, x, dx) where {F}
return hvp!(f, p, backend, x, dx, prepare_hvp(f, backend, x, dx))
function hvp!(f::F, dg, backend::AbstractADType, x, dx) where {F}
return hvp!(f, dg, backend, x, dx, prepare_hvp(f, backend, x, dx))
end

function hvp(f::F, backend::AbstractADType, x, dx, extras::HVPExtras) where {F}
Expand Down Expand Up @@ -178,35 +178,35 @@ function hvp(
return pullback(inner_gradient, outer(backend), x, dx, outer_pullback_extras)
end

function hvp!(f::F, p, backend::AbstractADType, x, dx, extras::HVPExtras) where {F}
return hvp!(f, p, SecondOrder(backend, backend), x, dx, extras)
function hvp!(f::F, dg, backend::AbstractADType, x, dx, extras::HVPExtras) where {F}
return hvp!(f, dg, SecondOrder(backend, backend), x, dx, extras)
end

function hvp!(
f::F, p, backend::SecondOrder, x, dx, extras::ForwardOverForwardHVPExtras
f::F, dg, backend::SecondOrder, x, dx, extras::ForwardOverForwardHVPExtras
) where {F}
@compat (; inner_gradient, outer_pushforward_extras) = extras
return pushforward!(inner_gradient, p, outer(backend), x, dx, outer_pushforward_extras)
return pushforward!(inner_gradient, dg, outer(backend), x, dx, outer_pushforward_extras)
end

function hvp!(
f::F, p, backend::SecondOrder, x, dx, extras::ForwardOverReverseHVPExtras
f::F, dg, backend::SecondOrder, x, dx, extras::ForwardOverReverseHVPExtras
) where {F}
@compat (; inner_gradient, outer_pushforward_extras) = extras
return pushforward!(inner_gradient, p, outer(backend), x, dx, outer_pushforward_extras)
return pushforward!(inner_gradient, dg, outer(backend), x, dx, outer_pushforward_extras)
end

function hvp!(
f::F, p, backend::SecondOrder, x, dx, extras::ReverseOverForwardHVPExtras
f::F, dg, backend::SecondOrder, x, dx, extras::ReverseOverForwardHVPExtras
) where {F}
@compat (; outer_gradient_extras) = extras
inner_pushforward = InnerPushforwardFixedSeed(f, nested(inner(backend)), dx)
return gradient!(inner_pushforward, p, outer(backend), x, outer_gradient_extras)
return gradient!(inner_pushforward, dg, outer(backend), x, outer_gradient_extras)
end

function hvp!(
f::F, p, backend::SecondOrder, x, dx, extras::ReverseOverReverseHVPExtras
f::F, dg, backend::SecondOrder, x, dx, extras::ReverseOverReverseHVPExtras
) where {F}
@compat (; inner_gradient, outer_pullback_extras) = extras
return pullback!(inner_gradient, p, outer(backend), x, dx, outer_pullback_extras)
return pullback!(inner_gradient, dg, outer(backend), x, dx, outer_pullback_extras)
end
2 changes: 1 addition & 1 deletion DifferentiationInterfaceTest/src/tests/benchmark.jl
Original file line number Diff line number Diff line change
Expand Up @@ -946,7 +946,7 @@ function run_benchmark!(
# benchmark
extras = prepare_hvp(f, ba, x, dx)
bench0 = @be prepare_hvp(f, ba, x, dx) samples = 1 evals = 1
bench1 = @be (p=mysimilar(x), ext=deepcopy(extras)) hvp!(f, _.p, ba, x, dx, _.ext) evals = 1
bench1 = @be (dg=mysimilar(x), ext=deepcopy(extras)) hvp!(f, _.dg, ba, x, dx, _.ext) evals = 1
# count
cc = CallCounter(f)
extras = prepare_hvp(cc, ba, x, dx)
Expand Down
16 changes: 8 additions & 8 deletions DifferentiationInterfaceTest/src/tests/correctness.jl
Original file line number Diff line number Diff line change
Expand Up @@ -910,7 +910,7 @@ function test_correctness(
ref_backend,
)
@compat (; f, x, dx) = new_scen = deepcopy(scen)
p_true = if ref_backend isa AbstractADType
dg_true = if ref_backend isa AbstractADType
hvp(f, ref_backend, x, dx)
else
new_scen.ref(x, dx)
Expand All @@ -921,14 +921,14 @@ function test_correctness(
(prepare_hvp(f, ba, mycopy_random(x), mycopy_random(dx)),),
(prepare_hvp_same_point(f, ba, x, mycopy_random(dx)),),
])
p1 = hvp(f, ba, x, dx, extras_tup...)
dg1 = hvp(f, ba, x, dx, extras_tup...)

let ()(x, y) = isapprox(x, y; atol, rtol)
@testset "Extras type" begin
@test isempty(extras_tup) || only(extras_tup) isa HVPExtras
end
@testset "HVP value" begin
@test p1 p_true
@test dg1 dg_true
end
end
end
Expand All @@ -945,7 +945,7 @@ function test_correctness(
ref_backend,
)
@compat (; f, x, dx) = new_scen = deepcopy(scen)
p_true = if ref_backend isa AbstractADType
dg_true = if ref_backend isa AbstractADType
hvp(f, ref_backend, x, dx)
else
new_scen.ref(x, dx)
Expand All @@ -956,16 +956,16 @@ function test_correctness(
(prepare_hvp(f, ba, mycopy_random(x), mycopy_random(dx)),),
(prepare_hvp_same_point(f, ba, x, mycopy_random(dx)),),
])
p1_in = mysimilar(x)
p1 = hvp!(f, p1_in, ba, x, dx, extras_tup...)
dg1_in = mysimilar(x)
dg1 = hvp!(f, dg1_in, ba, x, dx, extras_tup...)

let ()(x, y) = isapprox(x, y; atol, rtol)
@testset "Extras type" begin
@test isempty(extras_tup) || only(extras_tup) isa HVPExtras
end
@testset "HVP value" begin
@test p1_in p_true
@test p1 p_true
@test dg1_in dg_true
@test dg1 dg_true
end
end
end
Expand Down

0 comments on commit fa551e5

Please sign in to comment.