Skip to content

Commit

Permalink
Better seed handling in Jacobian and Hessian (#334)
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle committed Jun 26, 2024
1 parent 15e089d commit 8dc755b
Show file tree
Hide file tree
Showing 7 changed files with 329 additions and 318 deletions.
96 changes: 36 additions & 60 deletions DifferentiationInterface/src/first_order/jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,13 @@ abstract type JacobianExtras <: Extras end
struct NoJacobianExtras <: JacobianExtras end

struct PushforwardJacobianExtras{B,D,E<:PushforwardExtras,Y} <: JacobianExtras
seeds::D
batched_seeds::Vector{Batch{B,D}}
pushforward_batched_extras::E
y_example::Y
end

struct PullbackJacobianExtras{B,D,E<:PullbackExtras,Y} <: JacobianExtras
seeds::D
batched_seeds::Vector{Batch{B,D}}
pullback_batched_extras::E
y_example::Y
end
Expand All @@ -80,26 +80,38 @@ function prepare_jacobian_aux(f_or_f!y::FY, backend, x, y, ::PushforwardFast) wh
N = length(x)
B = pick_batchsize(backend, N)
seeds = [basis(backend, x, ind) for ind in CartesianIndices(x)]
batched_seeds =
Batch.([
ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for
a in 1:div(N, B, RoundUp)
])
pushforward_batched_extras = prepare_pushforward_batched(
f_or_f!y..., backend, x, Batch(ntuple(Returns(seeds[1]), Val(B)))
f_or_f!y..., backend, x, batched_seeds[1]
)
D = typeof(seeds)
D = eltype(seeds)
E = typeof(pushforward_batched_extras)
Y = typeof(y)
return PushforwardJacobianExtras{B,D,E,Y}(seeds, pushforward_batched_extras, copy(y))
return PushforwardJacobianExtras{B,D,E,Y}(
batched_seeds, pushforward_batched_extras, copy(y)
)
end

function prepare_jacobian_aux(f_or_f!y::FY, backend, x, y, ::PushforwardSlow) where {FY}
M = length(y)
B = pick_batchsize(backend, M)
seeds = [basis(backend, y, ind) for ind in CartesianIndices(y)]
batched_seeds =
Batch.([
ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % M], Val(B)) for
a in 1:div(M, B, RoundUp)
])
pullback_batched_extras = prepare_pullback_batched(
f_or_f!y..., backend, x, Batch(ntuple(Returns(seeds[1]), Val(B)))
f_or_f!y..., backend, x, batched_seeds[1]
)
D = typeof(seeds)
D = eltype(seeds)
E = typeof(pullback_batched_extras)
Y = typeof(y)
return PullbackJacobianExtras{B,D,E,Y}(seeds, pullback_batched_extras, copy(y))
return PullbackJacobianExtras{B,D,E,Y}(batched_seeds, pullback_batched_extras, copy(y))
end

## One argument
Expand Down Expand Up @@ -197,27 +209,16 @@ end
function jacobian_aux(
f_or_f!y::FY, backend, x::AbstractArray, extras::PushforwardJacobianExtras{B}
) where {FY,B}
@compat (; seeds, pushforward_batched_extras, y_example) = extras
@compat (; batched_seeds, pushforward_batched_extras, y_example) = extras
N = length(x)

pushforward_batched_extras_same = prepare_pushforward_batched_same_point(
f_or_f!y...,
backend,
x,
Batch(ntuple(Returns(seeds[1]), Val(B))),
pushforward_batched_extras,
f_or_f!y..., backend, x, batched_seeds[1], pushforward_batched_extras
)

jac_blocks = map(1:div(N, B, RoundUp)) do a
dx_batch_elements = ntuple(Val(B)) do b
seeds[1 + ((a - 1) * B + (b - 1)) % N]
end
jac_blocks = map(eachindex(batched_seeds)) do a
dy_batch = pushforward_batched(
f_or_f!y...,
backend,
x,
Batch(dx_batch_elements),
pushforward_batched_extras_same,
f_or_f!y..., backend, x, batched_seeds[a], pushforward_batched_extras_same
)
stack(vec, dy_batch.elements; dims=2)
end
Expand All @@ -232,27 +233,16 @@ end
function jacobian_aux(
f_or_f!y::FY, backend, x::AbstractArray, extras::PullbackJacobianExtras{B}
) where {FY,B}
@compat (; seeds, pullback_batched_extras, y_example) = extras
@compat (; batched_seeds, pullback_batched_extras, y_example) = extras
M = length(y_example)

pullback_batched_extras_same = prepare_pullback_batched_same_point(
f_or_f!y...,
backend,
x,
Batch(ntuple(Returns(seeds[1]), Val(B))),
extras.pullback_batched_extras,
f_or_f!y..., backend, x, batched_seeds[1], extras.pullback_batched_extras
)

jac_blocks = map(1:div(M, B, RoundUp)) do a
dy_batch_elements = ntuple(Val(B)) do b
seeds[1 + ((a - 1) * B + (b - 1)) % M]
end
jac_blocks = map(eachindex(batched_seeds)) do a
dx_batch = pullback_batched(
f_or_f!y...,
backend,
x,
Batch(dy_batch_elements),
pullback_batched_extras_same,
f_or_f!y..., backend, x, batched_seeds[a], pullback_batched_extras_same
)
stack(vec, dx_batch.elements; dims=1)
end
Expand All @@ -271,21 +261,14 @@ function jacobian_aux!(
x::AbstractArray,
extras::PushforwardJacobianExtras{B},
) where {FY,B}
@compat (; seeds, pushforward_batched_extras, y_example) = extras
@compat (; batched_seeds, pushforward_batched_extras, y_example) = extras
N = length(x)

pushforward_batched_extras_same = prepare_pushforward_batched_same_point(
f_or_f!y...,
backend,
x,
Batch(ntuple(Returns(seeds[1]), Val(B))),
pushforward_batched_extras,
f_or_f!y..., backend, x, batched_seeds[1], pushforward_batched_extras
)

for a in 1:div(N, B, RoundUp)
dx_batch_elements = ntuple(Val(B)) do b
seeds[1 + ((a - 1) * B + (b - 1)) % N]
end
for a in eachindex(batched_seeds)
dy_batch_elements = ntuple(Val(B)) do b
reshape(view(jac, :, 1 + ((a - 1) * B + (b - 1)) % N), size(y_example))
end
Expand All @@ -294,7 +277,7 @@ function jacobian_aux!(
Batch(dy_batch_elements),
backend,
x,
Batch(dx_batch_elements),
batched_seeds[a],
pushforward_batched_extras_same,
)
end
Expand All @@ -309,21 +292,14 @@ function jacobian_aux!(
x::AbstractArray,
extras::PullbackJacobianExtras{B},
) where {FY,B}
@compat (; seeds, pullback_batched_extras, y_example) = extras
@compat (; batched_seeds, pullback_batched_extras, y_example) = extras
M = length(y_example)

pullback_batched_extras_same = prepare_pullback_batched_same_point(
f_or_f!y...,
backend,
x,
Batch(ntuple(Returns(seeds[1]), Val(B))),
extras.pullback_batched_extras,
f_or_f!y..., backend, x, batched_seeds[1], extras.pullback_batched_extras
)

for a in 1:div(M, B, RoundUp)
dy_batch_elements = ntuple(Val(B)) do b
seeds[1 + ((a - 1) * B + (b - 1)) % M]
end
for a in eachindex(batched_seeds)
dx_batch_elements = ntuple(Val(B)) do b
reshape(view(jac, 1 + ((a - 1) * B + (b - 1)) % M, :), size(x))
end
Expand All @@ -332,7 +308,7 @@ function jacobian_aux!(
Batch(dx_batch_elements),
backend,
x,
Batch(dy_batch_elements),
batched_seeds[a],
pullback_batched_extras_same,
)
end
Expand Down
Loading

0 comments on commit 8dc755b

Please sign in to comment.