Skip to content

Commit

Permalink
Fix DifferentiationInterfaceTest tutorial (#335)
Browse files Browse the repository at this point in the history
* Fix DIT tutorial

* Make Scenario public

* More stuff
  • Loading branch information
gdalle committed Jun 26, 2024
1 parent 8dc755b commit 415dc6c
Show file tree
Hide file tree
Showing 7 changed files with 33 additions and 20 deletions.
1 change: 0 additions & 1 deletion DifferentiationInterface/docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ makedocs(;
"Reference" => ["operators.md", "backends.md", "api.md"],
"Advanced" => ["dev_guide.md", "overloads.md"],
],
checkdocs=:exports,
plugins=[links],
)

Expand Down
1 change: 1 addition & 0 deletions DifferentiationInterfaceTest/docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
DifferentiationInterfaceTest = "a82114a7-5aa3-49a8-9643-716bb13727a3"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
DocumenterInterLinks = "d12716ef-a0f6-4df4-a9f1-a5a34e75c656"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
Expand Down
1 change: 1 addition & 0 deletions DifferentiationInterfaceTest/docs/make.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using DifferentiationInterface
using DifferentiationInterfaceTest
using Documenter
using DocumenterInterLinks

using BenchmarkTools: BenchmarkTools
using DataFrames: DataFrames
Expand Down
1 change: 1 addition & 0 deletions DifferentiationInterfaceTest/docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ static_scenarios
## Scenario types

```@docs
Scenario
PushforwardScenario
PullbackScenario
DerivativeScenario
Expand Down
22 changes: 5 additions & 17 deletions DifferentiationInterfaceTest/docs/src/tutorial.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ We present a typical workflow with DifferentiationInterfaceTest.jl, building on
```@repl tuto
using DifferentiationInterface, DifferentiationInterfaceTest
import ForwardDiff, Enzyme
import Markdown, PrettyTables, Printf
```

## Introduction
Expand All @@ -31,7 +30,8 @@ Of course we know the true gradient mapping:
DifferentiationInterfaceTest.jl relies with so-called "scenarios", in which you encapsulate the information needed for your test:

- the function `f`
- the input `x` and output `y`
- the input `x` and output `y` of the function `f`
- the reference output of the operator (here `grad`)
- the number of arguments for `f` (either `1` or `2`)
- the behavior of the operator (either `:inplace` or `:outofplace`)

Expand All @@ -41,8 +41,8 @@ There is one scenario constructor per operator, and so here we will use [`Gradie
xv = rand(Float32, 3)
xm = rand(Float64, 3, 2)
scenarios = [
GradientScenario(f; x=xv, y=f(xv), nb_args=1, place=:inplace),
GradientScenario(f; x=xm, y=f(xm), nb_args=1, place=:inplace)
GradientScenario(f; x=xv, y=f(xv), grad=∇f(xv), nb_args=1, place=:inplace),
GradientScenario(f; x=xm, y=f(xm), grad=∇f(xm), nb_args=1, place=:inplace)
];
nothing # hide
```
Expand Down Expand Up @@ -73,16 +73,4 @@ This is made easy by the [`benchmark_differentiation`](@ref) function, whose syn
df = benchmark_differentiation(backends, scenarios);
```

The resulting object is `DataFrame` from [DataFrames.jl](https://github.com/JuliaData/DataFrames.jl), whose columns correspond to the fields of [`DifferentiationBenchmarkDataRow`](@ref):
Here's what it looks like with all of its columns.

```@example tuto
table = PrettyTables.pretty_table(
String,
df;
backend=Val(:markdown),
header=names(df),
)
Markdown.parse(table)
```
The resulting object is a `DataFrame` from [DataFrames.jl](https://github.com/JuliaData/DataFrames.jl), whose columns correspond to the fields of [`DifferentiationBenchmarkDataRow`](@ref):
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ include("tests/sparsity.jl")
include("tests/benchmark.jl")
include("test_differentiation.jl")

export Scenario
export PushforwardScenario,
PullbackScenario,
DerivativeScenario,
Expand Down
26 changes: 24 additions & 2 deletions DifferentiationInterfaceTest/src/scenarios/scenario.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ This generic type should never be used directly: use the specific constructor co
# Fields
$(TYPEDFIELDS)
Note that the `res1` and `res2` fields are given more meaningful names in the keyword arguments of each specialized constructor.
For example:
- the keyword `grad` of `GradientScenario` becomes `res1`
- the keyword `hess` of `HessianScenario` becomes `res2`, and the keyword `grad` becomes `res1`
"""
struct Scenario{op,args,pl,F,X,Y,D,R1,R2}
"function `f` (if `args==1`) or `f!` (if `args==2`) to apply"
Expand All @@ -35,9 +41,9 @@ struct Scenario{op,args,pl,F,X,Y,D,R1,R2}
y::Y
"seed for pushforward, pullback or HVP"
seed::D
"first-order result"
"first-order result of the operator"
res1::R1
"second-order result"
"second-order result of the operator (when it makes sense)"
res2::R2

function Scenario{op,args,pl}(
Expand Down Expand Up @@ -120,20 +126,26 @@ end

"""
$(SIGNATURES)
Construct a [`Scenario`](@ref) to test `pushforward` and its variants.
"""
function PushforwardScenario(f; x, y, dx, dy=nothing, nb_args, place=:inplace)
return Scenario{:pushforward,nb_args,place}(f; x, y, seed=dx, res1=dy, res2=nothing)
end

"""
$(SIGNATURES)
Construct a [`Scenario`](@ref) to test `pullback` and its variants.
"""
function PullbackScenario(f; x, y, dy, dx=nothing, nb_args, place=:inplace)
return Scenario{:pullback,nb_args,place}(f; x, y, seed=dy, res1=dx, res2=nothing)
end

"""
$(SIGNATURES)
Construct a [`Scenario`](@ref) to test `derivative` and its variants.
"""
function DerivativeScenario(f; x, y, der=nothing, nb_args, place=:inplace)
return Scenario{:derivative,nb_args,place}(
Expand All @@ -143,20 +155,26 @@ end

"""
$(SIGNATURES)
Construct a [`Scenario`](@ref) to test `gradient` and its variants.
"""
function GradientScenario(f; x, y, grad=nothing, nb_args, place=:inplace)
return Scenario{:gradient,nb_args,place}(f; x, y, seed=nothing, res1=grad, res2=nothing)
end

"""
$(SIGNATURES)
Construct a [`Scenario`](@ref) to test `jacobian` and its variants.
"""
function JacobianScenario(f; x, y, jac=nothing, nb_args, place=:inplace)
return Scenario{:jacobian,nb_args,place}(f; x, y, seed=nothing, res1=jac, res2=nothing)
end

"""
$(SIGNATURES)
Construct a [`Scenario`](@ref) to test `second_derivative` and its variants.
"""
function SecondDerivativeScenario(
f; x, y, der=nothing, der2=nothing, nb_args, place=:inplace
Expand All @@ -168,13 +186,17 @@ end

"""
$(SIGNATURES)
Construct a [`Scenario`](@ref) to test `hvp` and its variants.
"""
function HVPScenario(f; x, y, dx, grad=nothing, dg=nothing, nb_args, place=:inplace)
return Scenario{:hvp,nb_args,place}(f; x, y, seed=dx, res1=grad, res2=dg)
end

"""
$(SIGNATURES)
Construct a [`Scenario`](@ref) to test `hessian` and its variants.
"""
function HessianScenario(f; x, y, grad=nothing, hess=nothing, nb_args, place=:inplace)
return Scenario{:hessian,nb_args,place}(f; x, y, seed=nothing, res1=grad, res2=hess)
Expand Down

0 comments on commit 415dc6c

Please sign in to comment.