Skip to content

Commit

Permalink
Improve type stability of Jacobians and Hessian, fix test scenarios (#…
Browse files Browse the repository at this point in the history
…337)

* Ensure type stability of test scenarios in 1.11

* Preallocate results

* Typo

* Fix GPU scenario

* Remporarily disable tests on 1.11

* I said disable them
  • Loading branch information
gdalle committed Jun 26, 2024
1 parent 83de009 commit ff529cb
Show file tree
Hide file tree
Showing 7 changed files with 150 additions and 93 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/Test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ jobs:
version:
- '1'
- '1.6'
- '~1.11.0-0'
# - '~1.11.0-0'
group:
- Formalities
- Internals
Expand Down Expand Up @@ -118,7 +118,7 @@ jobs:
version:
- '1'
- '1.6'
- '~1.11.0-0'
# - '~1.11.0-0'
group:
- Formalities
- Zero
Expand Down
72 changes: 43 additions & 29 deletions DifferentiationInterface/src/first_order/jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,16 +55,18 @@ abstract type JacobianExtras <: Extras end

struct NoJacobianExtras <: JacobianExtras end

struct PushforwardJacobianExtras{B,D,E<:PushforwardExtras,Y} <: JacobianExtras
struct PushforwardJacobianExtras{B,D,R,E<:PushforwardExtras} <: JacobianExtras
batched_seeds::Vector{Batch{B,D}}
batched_results::Vector{Batch{B,R}}
pushforward_batched_extras::E
y_example::Y
N::Int
end

struct PullbackJacobianExtras{B,D,E<:PullbackExtras,Y} <: JacobianExtras
struct PullbackJacobianExtras{B,D,R,E<:PullbackExtras} <: JacobianExtras
batched_seeds::Vector{Batch{B,D}}
batched_results::Vector{Batch{B,R}}
pullback_batched_extras::E
y_example::Y
M::Int
end

function prepare_jacobian(f::F, backend::AbstractADType, x) where {F}
Expand All @@ -85,14 +87,15 @@ function prepare_jacobian_aux(f_or_f!y::FY, backend, x, y, ::PushforwardFast) wh
ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for
a in 1:div(N, B, RoundUp)
])
batched_results = Batch.([ntuple(b -> similar(y), Val(B)) for _ in batched_seeds])
pushforward_batched_extras = prepare_pushforward_batched(
f_or_f!y..., backend, x, batched_seeds[1]
)
D = eltype(seeds)
D = eltype(batched_seeds[1])
R = eltype(batched_results[1])
E = typeof(pushforward_batched_extras)
Y = typeof(y)
return PushforwardJacobianExtras{B,D,E,Y}(
batched_seeds, pushforward_batched_extras, copy(y)
return PushforwardJacobianExtras{B,D,R,E}(
batched_seeds, batched_results, pushforward_batched_extras, N
)
end

Expand All @@ -105,13 +108,16 @@ function prepare_jacobian_aux(f_or_f!y::FY, backend, x, y, ::PushforwardSlow) wh
ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % M], Val(B)) for
a in 1:div(M, B, RoundUp)
])
batched_results = Batch.([ntuple(b -> similar(x), Val(B)) for _ in batched_seeds])
pullback_batched_extras = prepare_pullback_batched(
f_or_f!y..., backend, x, batched_seeds[1]
)
D = eltype(seeds)
D = eltype(batched_seeds[1])
R = eltype(batched_results[1])
E = typeof(pullback_batched_extras)
Y = typeof(y)
return PullbackJacobianExtras{B,D,E,Y}(batched_seeds, pullback_batched_extras, copy(y))
return PullbackJacobianExtras{B,D,R,E}(
batched_seeds, batched_results, pullback_batched_extras, M
)
end

## One argument
Expand Down Expand Up @@ -209,8 +215,7 @@ end
function jacobian_aux(
f_or_f!y::FY, backend, x::AbstractArray, extras::PushforwardJacobianExtras{B}
) where {FY,B}
@compat (; batched_seeds, pushforward_batched_extras, y_example) = extras
N = length(x)
@compat (; batched_seeds, pushforward_batched_extras, N) = extras

pushforward_batched_extras_same = prepare_pushforward_batched_same_point(
f_or_f!y..., backend, x, batched_seeds[1], pushforward_batched_extras
Expand All @@ -233,8 +238,7 @@ end
function jacobian_aux(
f_or_f!y::FY, backend, x::AbstractArray, extras::PullbackJacobianExtras{B}
) where {FY,B}
@compat (; batched_seeds, pullback_batched_extras, y_example) = extras
M = length(y_example)
@compat (; batched_seeds, pullback_batched_extras, M) = extras

pullback_batched_extras_same = prepare_pullback_batched_same_point(
f_or_f!y..., backend, x, batched_seeds[1], extras.pullback_batched_extras
Expand All @@ -261,27 +265,32 @@ function jacobian_aux!(
x::AbstractArray,
extras::PushforwardJacobianExtras{B},
) where {FY,B}
@compat (; batched_seeds, pushforward_batched_extras, y_example) = extras
N = length(x)
@compat (; batched_seeds, batched_results, pushforward_batched_extras, N) = extras

pushforward_batched_extras_same = prepare_pushforward_batched_same_point(
f_or_f!y..., backend, x, batched_seeds[1], pushforward_batched_extras
)

for a in eachindex(batched_seeds)
dy_batch_elements = ntuple(Val(B)) do b
reshape(view(jac, :, 1 + ((a - 1) * B + (b - 1)) % N), size(y_example))
end
for a in eachindex(batched_seeds, batched_results)
pushforward_batched!(
f_or_f!y...,
Batch(dy_batch_elements),
batched_results[a],
backend,
x,
batched_seeds[a],
pushforward_batched_extras_same,
)
end

for a in eachindex(batched_results)
for b in eachindex(batched_results[a].elements)
copyto!(
view(jac, :, 1 + ((a - 1) * B + (b - 1)) % N),
vec(batched_results[a].elements[b]),
)
end
end

return jac
end

Expand All @@ -292,26 +301,31 @@ function jacobian_aux!(
x::AbstractArray,
extras::PullbackJacobianExtras{B},
) where {FY,B}
@compat (; batched_seeds, pullback_batched_extras, y_example) = extras
M = length(y_example)
@compat (; batched_seeds, batched_results, pullback_batched_extras, M) = extras

pullback_batched_extras_same = prepare_pullback_batched_same_point(
f_or_f!y..., backend, x, batched_seeds[1], extras.pullback_batched_extras
)

for a in eachindex(batched_seeds)
dx_batch_elements = ntuple(Val(B)) do b
reshape(view(jac, 1 + ((a - 1) * B + (b - 1)) % M, :), size(x))
end
for a in eachindex(batched_seeds, batched_results)
pullback_batched!(
f_or_f!y...,
Batch(dx_batch_elements),
batched_results[a],
backend,
x,
batched_seeds[a],
pullback_batched_extras_same,
)
end

for a in eachindex(batched_results)
for b in eachindex(batched_results[a].elements)
copyto!(
view(jac, 1 + ((a - 1) * B + (b - 1)) % M, :),
vec(batched_results[a].elements[b]),
)
end
end

return jac
end
39 changes: 21 additions & 18 deletions DifferentiationInterface/src/second_order/hessian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,12 @@ abstract type HessianExtras <: Extras end

struct NoHessianExtras <: HessianExtras end

struct HVPGradientHessianExtras{B,D,E2<:HVPExtras,E1<:GradientExtras} <: HessianExtras
struct HVPGradientHessianExtras{B,D,R,E2<:HVPExtras,E1<:GradientExtras} <: HessianExtras
batched_seeds::Vector{Batch{B,D}}
batched_results::Vector{Batch{B,R}}
hvp_batched_extras::E2
gradient_extras::E1
N::Int
end

function prepare_hessian(f::F, backend::AbstractADType, x) where {F}
Expand All @@ -64,12 +66,14 @@ function prepare_hessian(f::F, backend::AbstractADType, x) where {F}
ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for
a in 1:div(N, B, RoundUp)
])
batched_results = Batch.([ntuple(b -> similar(x), Val(B)) for _ in batched_seeds])
hvp_batched_extras = prepare_hvp_batched(f, backend, x, batched_seeds[1])
gradient_extras = prepare_gradient(f, maybe_inner(backend), x)
D = eltype(seeds)
D = eltype(batched_seeds[1])
R = eltype(batched_results[1])
E2, E1 = typeof(hvp_batched_extras), typeof(gradient_extras)
return HVPGradientHessianExtras{B,D,E2,E1}(
batched_seeds, hvp_batched_extras, gradient_extras
return HVPGradientHessianExtras{B,D,R,E2,E1}(
batched_seeds, batched_results, hvp_batched_extras, gradient_extras, N
)
end

Expand Down Expand Up @@ -100,8 +104,7 @@ end
function hessian(
f::F, backend::AbstractADType, x, extras::HVPGradientHessianExtras{B}
) where {F,B}
@compat (; batched_seeds, hvp_batched_extras) = extras
N = length(x)
@compat (; batched_seeds, hvp_batched_extras, N) = extras

hvp_batched_extras_same = prepare_hvp_batched_same_point(
f, backend, x, batched_seeds[1], hvp_batched_extras
Expand All @@ -122,27 +125,27 @@ end
function hessian!(
f::F, hess, backend::AbstractADType, x, extras::HVPGradientHessianExtras{B}
) where {F,B}
@compat (; batched_seeds, hvp_batched_extras) = extras
N = length(x)
@compat (; batched_seeds, batched_results, hvp_batched_extras, N) = extras

hvp_batched_extras_same = prepare_hvp_batched_same_point(
f, backend, x, batched_seeds[1], hvp_batched_extras
)

for a in eachindex(batched_seeds)
dg_batch_elements = ntuple(Val(B)) do b
reshape(view(hess, :, 1 + ((a - 1) * B + (b - 1)) % N), size(x))
end
for a in eachindex(batched_seeds, batched_results)
hvp_batched!(
f,
Batch(dg_batch_elements),
backend,
x,
batched_seeds[a],
hvp_batched_extras_same,
f, batched_results[a], backend, x, batched_seeds[a], hvp_batched_extras_same
)
end

for a in eachindex(batched_results)
for b in eachindex(batched_results[a].elements)
copyto!(
view(hess, :, 1 + ((a - 1) * B + (b - 1)) % N),
vec(batched_results[a].elements[b]),
)
end
end

return hess
end

Expand Down
38 changes: 28 additions & 10 deletions DifferentiationInterface/src/sparse/hessian.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
struct SparseHessianExtras{
B,S<:AbstractMatrix{Bool},C<:AbstractMatrix{<:Real},D,E2<:HVPExtras,E1<:GradientExtras
B,S<:AbstractMatrix{Bool},C<:AbstractMatrix{<:Real},D,R,E2<:HVPExtras,E1<:GradientExtras
} <: HessianExtras
sparsity::S
colors::Vector{Int}
groups::Vector{Vector{Int}}
compressed::C
batched_seeds::Vector{Batch{B,D}}
batched_results::Vector{Batch{B,R}}
hvp_batched_extras::E2
gradient_extras::E1
end
Expand All @@ -16,16 +17,18 @@ function SparseHessianExtras{B}(;
groups,
compressed::C,
batched_seeds::Vector{Batch{B,D}},
batched_results::Vector{Batch{B,R}},
hvp_batched_extras::E2,
gradient_extras::E1,
) where {B,S,C,D,E2,E1}
) where {B,S,C,D,R,E2,E1}
@assert size(sparsity, 1) == size(sparsity, 2) == size(compressed, 1) == length(colors)
return SparseHessianExtras{B,S,C,D,E2,E1}(
return SparseHessianExtras{B,S,C,D,R,E2,E1}(
sparsity,
colors,
groups,
compressed,
batched_seeds,
batched_results,
hvp_batched_extras,
gradient_extras,
)
Expand All @@ -48,6 +51,7 @@ function prepare_hessian(f::F, backend::AutoSparse, x) where {F}
ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % Ng], Val(B)) for
a in 1:div(Ng, B, RoundUp)
])
batched_results = Batch.([ntuple(b -> similar(x), Val(B)) for _ in batched_seeds])
hvp_batched_extras = prepare_hvp_batched(f, dense_backend, x, batched_seeds[1])
gradient_extras = prepare_gradient(f, maybe_inner(dense_backend), x)
return SparseHessianExtras{B}(;
Expand All @@ -56,6 +60,7 @@ function prepare_hessian(f::F, backend::AutoSparse, x) where {F}
groups,
compressed,
batched_seeds,
batched_results,
hvp_batched_extras,
gradient_extras,
)
Expand Down Expand Up @@ -86,29 +91,42 @@ end
function hessian!(
f::F, hess, backend::AutoSparse, x, extras::SparseHessianExtras{B}
) where {F,B}
@compat (; sparsity, compressed, colors, groups, batched_seeds, hvp_batched_extras) =
extras
@compat (;
sparsity,
compressed,
colors,
groups,
batched_seeds,
batched_results,
hvp_batched_extras,
) = extras
dense_backend = dense_ad(backend)
Ng = length(groups)

hvp_batched_extras_same = prepare_hvp_batched_same_point(
f, dense_backend, x, batched_seeds[1], hvp_batched_extras
)

for a in 1:div(Ng, B, RoundUp)
dg_batch_elements = ntuple(Val(B)) do b
reshape(view(compressed, :, 1 + ((a - 1) * B + (b - 1)) % Ng), size(x))
end
for a in eachindex(batched_seeds, batched_results)
hvp_batched!(
f,
Batch(dg_batch_elements),
batched_results[a],
dense_backend,
x,
batched_seeds[a],
hvp_batched_extras_same,
)
end

for a in eachindex(batched_results)
for b in eachindex(batched_results[a].elements)
copyto!(
view(compressed, :, 1 + ((a - 1) * B + (b - 1)) % Ng),
vec(batched_results[a].elements[b]),
)
end
end

decompress_symmetric!(hess, sparsity, compressed, colors)
return hess
end
Expand Down
Loading

0 comments on commit ff529cb

Please sign in to comment.