Skip to content

Commit

Permalink
Batched HVP (#330)
Browse files Browse the repository at this point in the history
* Batched hvp

* Coverage
  • Loading branch information
gdalle committed Jun 25, 2024
1 parent 86f1e02 commit b8f82b0
Show file tree
Hide file tree
Showing 3 changed files with 160 additions and 68 deletions.
40 changes: 18 additions & 22 deletions DifferentiationInterface/src/first_order/pullback.jl
Original file line number Diff line number Diff line change
Expand Up @@ -140,29 +140,25 @@ end

### Batched

function prepare_pullback_batched(
f::F, backend::AbstractADType, x, dy::Batch{B}
) where {F,B}
# Batched preparation: delegates to `prepare_pullback`, passing only the first
# element of `dy.elements` as the seed.
function prepare_pullback_batched(f::F, backend::AbstractADType, x, dy::Batch) where {F}
    dy_first = first(dy.elements)
    return prepare_pullback(f, backend, x, dy_first)
end

function prepare_pullback_batched(
f!::F, y, backend::AbstractADType, x, dy::Batch{B}
) where {F,B}
# Batched preparation for the `f!, y` signature: delegates to `prepare_pullback`,
# passing only the first element of `dy.elements` as the seed.
function prepare_pullback_batched(f!::F, y, backend::AbstractADType, x, dy::Batch) where {F}
    dy_first = first(dy.elements)
    return prepare_pullback(f!, y, backend, x, dy_first)
end

### Batched, same point

function prepare_pullback_batched_same_point(
f::F, backend::AbstractADType, x, dy::Batch{B}, extras::PullbackExtras
) where {F,B}
f::F, backend::AbstractADType, x, dy::Batch, extras::PullbackExtras
) where {F}
return prepare_pullback_same_point(f, backend, x, first(dy.elements), extras)
end

function prepare_pullback_batched_same_point(
f!::F, y, backend::AbstractADType, x, dy::Batch{B}, extras::PullbackExtras
) where {F,B}
f!::F, y, backend::AbstractADType, x, dy::Batch, extras::PullbackExtras
) where {F}
return prepare_pullback_same_point(f!, y, backend, x, first(dy.elements), extras)
end

Expand Down Expand Up @@ -229,17 +225,17 @@ end
function pullback_batched(
f::F, backend::AbstractADType, x, dy::Batch{B}, extras::PullbackExtras
) where {F,B}
dx_elements = ntuple(Val(B)) do l
pullback(f, backend, x, dy.elements[l], extras)
dx_elements = ntuple(Val(B)) do b
pullback(f, backend, x, dy.elements[b], extras)
end
return Batch(dx_elements)
end

function pullback_batched!(
f::F, dx::Batch{B}, backend::AbstractADType, x, dy::Batch{B}, extras::PullbackExtras
) where {F,B}
for l in 1:B
pullback!(f, dx.elements[l], backend, x, dy.elements[l], extras)
f::F, dx::Batch, backend::AbstractADType, x, dy::Batch, extras::PullbackExtras
) where {F}
for b in eachindex(dx.elements, dy.elements)
pullback!(f, dx.elements[b], backend, x, dy.elements[b], extras)
end
return dx
end
Expand Down Expand Up @@ -307,17 +303,17 @@ end
function pullback_batched(
f!::F, y, backend::AbstractADType, x, dy::Batch{B}, extras::PullbackExtras
) where {F,B}
dx_elements = ntuple(Val(B)) do l
pullback(f!, y, backend, x, dy.elements[l], extras)
dx_elements = ntuple(Val(B)) do b
pullback(f!, y, backend, x, dy.elements[b], extras)
end
return Batch(dx_elements)
end

function pullback_batched!(
f!::F, y, dx::Batch{B}, backend::AbstractADType, x, dy::Batch{B}, extras::PullbackExtras
) where {F,B}
for l in 1:B
pullback!(f!, y, dx.elements[l], backend, x, dy.elements[l], extras)
f!::F, y, dx::Batch, backend::AbstractADType, x, dy::Batch, extras::PullbackExtras
) where {F}
for b in eachindex(dx.elements, dy.elements)
pullback!(f!, y, dx.elements[b], backend, x, dy.elements[b], extras)
end
return dx
end
46 changes: 19 additions & 27 deletions DifferentiationInterface/src/first_order/pushforward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -141,29 +141,27 @@ end

### Batched

function prepare_pushforward_batched(
f::F, backend::AbstractADType, x, dx::Batch{B}
) where {F,B}
# Batched preparation: delegates to `prepare_pushforward`, passing only the first
# element of `dx.elements` as the seed.
function prepare_pushforward_batched(f::F, backend::AbstractADType, x, dx::Batch) where {F}
    dx_first = first(dx.elements)
    return prepare_pushforward(f, backend, x, dx_first)
end

function prepare_pushforward_batched(
f!::F, y, backend::AbstractADType, x, dx::Batch{B}
) where {F,B}
f!::F, y, backend::AbstractADType, x, dx::Batch
) where {F}
return prepare_pushforward(f!, y, backend, x, first(dx.elements))
end

### Batched, same point

function prepare_pushforward_batched_same_point(
f::F, backend::AbstractADType, x, dx::Batch{B}, extras::PushforwardExtras
) where {F,B}
f::F, backend::AbstractADType, x, dx::Batch, extras::PushforwardExtras
) where {F}
return prepare_pushforward_same_point(f, backend, x, first(dx.elements), extras)
end

function prepare_pushforward_batched_same_point(
f!::F, y, backend::AbstractADType, x, dx::Batch{B}, extras::PushforwardExtras
) where {F,B}
f!::F, y, backend::AbstractADType, x, dx::Batch, extras::PushforwardExtras
) where {F}
return prepare_pushforward_same_point(f!, y, backend, x, first(dx.elements), extras)
end

Expand Down Expand Up @@ -234,17 +232,17 @@ end
function pushforward_batched(
f::F, backend::AbstractADType, x, dx::Batch{B}, extras::PushforwardExtras
) where {F,B}
dy_elements = ntuple(Val(B)) do l
pushforward(f, backend, x, dx.elements[l], extras)
dy_elements = ntuple(Val(B)) do b
pushforward(f, backend, x, dx.elements[b], extras)
end
return Batch(dy_elements)
end

function pushforward_batched!(
f::F, dy::Batch{B}, backend::AbstractADType, x, dx::Batch{B}, extras::PushforwardExtras
) where {F,B}
for l in 1:B
pushforward!(f, dy.elements[l], backend, x, dx.elements[l], extras)
f::F, dy::Batch, backend::AbstractADType, x, dx::Batch, extras::PushforwardExtras
) where {F}
for b in eachindex(dy.elements, dx.elements)
pushforward!(f, dy.elements[b], backend, x, dx.elements[b], extras)
end
return dy
end
Expand Down Expand Up @@ -316,23 +314,17 @@ end
function pushforward_batched(
f!::F, y, backend::AbstractADType, x, dx::Batch{B}, extras::PushforwardExtras
) where {F,B}
dy_elements = ntuple(Val(B)) do l
pushforward(f!, y, backend, x, dx.elements[l], extras)
dy_elements = ntuple(Val(B)) do b
pushforward(f!, y, backend, x, dx.elements[b], extras)
end
return Batch(dy_elements)
end

function pushforward_batched!(
f!::F,
y,
dy::Batch{B},
backend::AbstractADType,
x,
dx::Batch{B},
extras::PushforwardExtras,
) where {F,B}
for l in 1:B
pushforward!(f!, y, dy.elements[l], backend, x, dx.elements[l], extras)
f!::F, y, dy::Batch, backend::AbstractADType, x, dx::Batch, extras::PushforwardExtras
) where {F}
for b in eachindex(dy.elements, dx.elements)
pushforward!(f!, y, dy.elements[b], backend, x, dx.elements[b], extras)
end
return dy
end
142 changes: 123 additions & 19 deletions DifferentiationInterface/src/second_order/hvp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,33 +102,33 @@ function prepare_hvp(f::F, backend::AbstractADType, x, dx) where {F}
end

function prepare_hvp(f::F, backend::SecondOrder, x, dx) where {F}
return prepare_hvp(f, backend, x, dx, hvp_mode(backend))
return prepare_hvp_aux(f, backend, x, dx, hvp_mode(backend))
end

function prepare_hvp(f::F, backend::SecondOrder, x, dx, ::ForwardOverForward) where {F}
function prepare_hvp_aux(f::F, backend::SecondOrder, x, dx, ::ForwardOverForward) where {F}
# pushforward of many pushforwards in theory, but pushforward of gradient in practice
inner_gradient = InnerGradient(f, nested(inner(backend)))
outer_pushforward_extras = prepare_pushforward(inner_gradient, outer(backend), x, dx)
return ForwardOverForwardHVPExtras(inner_gradient, outer_pushforward_extras)
end

function prepare_hvp(f::F, backend::SecondOrder, x, dx, ::ForwardOverReverse) where {F}
function prepare_hvp_aux(f::F, backend::SecondOrder, x, dx, ::ForwardOverReverse) where {F}
# pushforward of gradient
inner_gradient = InnerGradient(f, nested(inner(backend)))
outer_pushforward_extras = prepare_pushforward(inner_gradient, outer(backend), x, dx)
return ForwardOverReverseHVPExtras(inner_gradient, outer_pushforward_extras)
end

function prepare_hvp(f::F, backend::SecondOrder, x, dx, ::ReverseOverForward) where {F}
function prepare_hvp_aux(f::F, backend::SecondOrder, x, dx, ::ReverseOverForward) where {F}
# gradient of pushforward
# uses dx in the closure so it can't be stored
inner_pushforward = InnerPushforwardFixedSeed(f, nested(inner(backend)), dx)
outer_gradient_extras = prepare_gradient(inner_pushforward, outer(backend), x)
return ReverseOverForwardHVPExtras(outer_gradient_extras)
end

function prepare_hvp(f::F, backend::SecondOrder, x, dx, ::ReverseOverReverse) where {F}
# pullback of the gradient
function prepare_hvp_aux(f::F, backend::SecondOrder, x, dx, ::ReverseOverReverse) where {F}
# pullback of gradient
inner_gradient = InnerGradient(f, nested(inner(backend)))
outer_pullback_extras = prepare_pullback(inner_gradient, outer(backend), x, dx)
return ReverseOverReverseHVPExtras(inner_gradient, outer_pullback_extras)
Expand All @@ -149,16 +149,58 @@ end

### Batched

function prepare_hvp_batched(f::F, backend::AbstractADType, x, dx::Batch{B}) where {F,B}
return prepare_hvp(f, backend, x, first(dx.elements))
function prepare_hvp_batched(f::F, backend::AbstractADType, x, dx::Batch) where {F}
return prepare_hvp_batched(f, SecondOrder(backend, backend), x, dx)
end

# Entry point for `SecondOrder` backends: query `hvp_mode(backend)` and let the
# resulting mode object select the matching `prepare_hvp_batched_aux` method.
function prepare_hvp_batched(f::F, backend::SecondOrder, x, dx::Batch) where {F}
    mode = hvp_mode(backend)
    return prepare_hvp_batched_aux(f, backend, x, dx, mode)
end

# Forward-over-forward mode: batched pushforward of gradient.
# Builds an `InnerGradient` around `f` with the nested inner backend, prepares a
# batched pushforward of it with the outer backend, and stores both in the extras.
function prepare_hvp_batched_aux(
    f::F, backend::SecondOrder, x, dx::Batch, ::ForwardOverForward
) where {F}
    inner_backend = nested(inner(backend))
    inner_gradient = InnerGradient(f, inner_backend)
    outer_backend = outer(backend)
    pf_extras = prepare_pushforward_batched(inner_gradient, outer_backend, x, dx)
    return ForwardOverForwardHVPExtras(inner_gradient, pf_extras)
end

# Forward-over-reverse mode: batched pushforward of gradient.
# Builds an `InnerGradient` around `f` with the nested inner backend, prepares a
# batched pushforward of it with the outer backend, and stores both in the extras.
function prepare_hvp_batched_aux(
    f::F, backend::SecondOrder, x, dx::Batch, ::ForwardOverReverse
) where {F}
    inner_backend = nested(inner(backend))
    inner_gradient = InnerGradient(f, inner_backend)
    outer_backend = outer(backend)
    pf_extras = prepare_pushforward_batched(inner_gradient, outer_backend, x, dx)
    return ForwardOverReverseHVPExtras(inner_gradient, pf_extras)
end

# Reverse-over-forward mode: no dedicated batched preparation yet, so this falls
# back to the non-batched `prepare_hvp_aux` using the first element of `dx.elements`.
function prepare_hvp_batched_aux(
    f::F, backend::SecondOrder, x, dx::Batch, ::ReverseOverForward
) where {F}
    # TODO: batched version replacing the outer gradient with a pullback
    dx_first = first(dx.elements)
    return prepare_hvp_aux(f, backend, x, dx_first, ReverseOverForward())
end

# Reverse-over-reverse mode: batched pullback of gradient.
# Builds an `InnerGradient` around `f` with the nested inner backend, prepares a
# batched pullback of it with the outer backend, and stores both in the extras.
function prepare_hvp_batched_aux(
    f::F, backend::SecondOrder, x, dx::Batch, ::ReverseOverReverse
) where {F}
    inner_backend = nested(inner(backend))
    inner_gradient = InnerGradient(f, inner_backend)
    outer_backend = outer(backend)
    pb_extras = prepare_pullback_batched(inner_gradient, outer_backend, x, dx)
    return ReverseOverReverseHVPExtras(inner_gradient, pb_extras)
end

### Batched, same point

function prepare_hvp_batched_same_point(
f::F, backend::AbstractADType, x, dx::Batch{B}, extras::HVPExtras
) where {F,B}
return prepare_hvp_same_point(f, backend, x, first(dx.elements), extras)
f::F, backend::AbstractADType, x, dx::Batch, extras::HVPExtras
) where {F}
return extras
end

## One argument
Expand Down Expand Up @@ -241,27 +283,89 @@ end

### Batched

function hvp_batched(f::F, backend::AbstractADType, x, dx, extras::HVPExtras) where {F}
# Generic-backend method: wrap `backend` as `SecondOrder(backend, backend)` and
# forward to the `SecondOrder` methods of `hvp_batched`.
function hvp_batched(
    f::F, backend::AbstractADType, x, dx::Batch, extras::HVPExtras
) where {F}
    second_order_backend = SecondOrder(backend, backend)
    return hvp_batched(f, second_order_backend, x, dx, extras)
end

function hvp_batched(
f::F, backend::SecondOrder, x, dx::Batch{B}, extras::HVPExtras
f::F, backend::SecondOrder, x, dx::Batch, extras::ForwardOverForwardHVPExtras
) where {F}
@compat (; inner_gradient, outer_pushforward_extras) = extras
return pushforward_batched(
inner_gradient, outer(backend), x, dx, outer_pushforward_extras
)
end

# Forward-over-reverse extras: the result is a batched pushforward of the stored
# `inner_gradient` with the outer backend. `f` itself is not used in the body.
function hvp_batched(
    f::F, backend::SecondOrder, x, dx::Batch, extras::ForwardOverReverseHVPExtras
) where {F}
    inner_gradient = extras.inner_gradient
    pf_extras = extras.outer_pushforward_extras
    return pushforward_batched(inner_gradient, outer(backend), x, dx, pf_extras)
end

function hvp_batched(
f::F, backend::SecondOrder, x, dx::Batch{B}, extras::ReverseOverForwardHVPExtras
) where {F,B}
dg_elements = ntuple(Val(B)) do l
hvp(f, backend, x, dx.elements[l], extras)
dg_elements = ntuple(Val(B)) do b
hvp(f, backend, x, dx.elements[b], extras)
end
return Batch(dg_elements)
end

function hvp_batched!(f::F, dg, backend::AbstractADType, x, dx, extras::HVPExtras) where {F}
# Reverse-over-reverse extras: the result is a batched pullback of the stored
# `inner_gradient` with the outer backend. `f` itself is not used in the body.
function hvp_batched(
    f::F, backend::SecondOrder, x, dx::Batch, extras::ReverseOverReverseHVPExtras
) where {F}
    inner_gradient = extras.inner_gradient
    pb_extras = extras.outer_pullback_extras
    return pullback_batched(inner_gradient, outer(backend), x, dx, pb_extras)
end

# Generic-backend in-place method: wrap `backend` as `SecondOrder(backend, backend)`
# and forward to the `SecondOrder` methods of `hvp_batched!`.
function hvp_batched!(
    f::F, dg::Batch, backend::AbstractADType, x, dx::Batch, extras::HVPExtras
) where {F}
    second_order_backend = SecondOrder(backend, backend)
    return hvp_batched!(f, dg, second_order_backend, x, dx, extras)
end

function hvp_batched!(
f::F, dg::Batch{B}, backend::SecondOrder, x, dx::Batch{B}, extras::HVPExtras
f::F, dg::Batch, backend::SecondOrder, x, dx::Batch, extras::ForwardOverForwardHVPExtras
) where {F}
@compat (; inner_gradient, outer_pushforward_extras) = extras
return pushforward_batched!(
inner_gradient, dg, outer(backend), x, dx, outer_pushforward_extras
)
end

# Forward-over-reverse extras, in-place: delegates to `pushforward_batched!` on the
# stored `inner_gradient`, passing `dg` as the output batch. `f` is not used here.
function hvp_batched!(
    f::F, dg::Batch, backend::SecondOrder, x, dx::Batch, extras::ForwardOverReverseHVPExtras
) where {F}
    inner_gradient = extras.inner_gradient
    pf_extras = extras.outer_pushforward_extras
    return pushforward_batched!(inner_gradient, dg, outer(backend), x, dx, pf_extras)
end

function hvp_batched!(
f::F,
dg::Batch{B},
backend::SecondOrder,
x,
dx::Batch{B},
extras::ReverseOverForwardHVPExtras,
) where {F,B}
for l in 1:B
hvp!(f, dg.elements[l], backend, x, dx.elements[l], extras)
for b in eachindex(dg.elements, dx.elements)
hvp!(f, dg.elements[b], backend, x, dx.elements[b], extras)
end
return dg
end

# Reverse-over-reverse extras, in-place: delegates to `pullback_batched!` on the
# stored `inner_gradient`, passing `dg` as the output batch. `f` is not used here.
function hvp_batched!(
    f::F, dg::Batch, backend::SecondOrder, x, dx::Batch, extras::ReverseOverReverseHVPExtras
) where {F}
    inner_gradient = extras.inner_gradient
    pb_extras = extras.outer_pullback_extras
    return pullback_batched!(inner_gradient, dg, outer(backend), x, dx, pb_extras)
end

0 comments on commit b8f82b0

Please sign in to comment.