Skip to content

Commit

Permalink
Remove custom printing (#332)
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle committed Jun 26, 2024
1 parent bb0d874 commit feb17b0
Show file tree
Hide file tree
Showing 12 changed files with 72 additions and 72 deletions.
40 changes: 19 additions & 21 deletions DifferentiationInterface/docs/src/backends.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ We support all dense backend choices from [ADTypes.jl](https://github.com/SciML/

```@setup backends
using DifferentiationInterface
using DifferentiationInterface: backend_str
import Markdown
import Diffractor
Expand All @@ -25,21 +24,21 @@ import Tapir
import Tracker
import Zygote
const backend_examples = (
"AutoDiffractor()",
"AutoEnzyme(; mode=Enzyme.Forward)",
"AutoEnzyme(; mode=Enzyme.Reverse)",
"AutoFastDifferentiation()",
"AutoFiniteDiff()",
"AutoFiniteDifferences(; fdm=FiniteDifferences.central_fdm(3, 1))",
"AutoForwardDiff()",
"AutoPolyesterForwardDiff(; chunksize=1)",
"AutoReverseDiff()",
"AutoSymbolics()",
"AutoTapir(; safe_mode=false)",
"AutoTracker()",
"AutoZygote()",
)
backend_examples = [
AutoDiffractor(),
AutoEnzyme(; mode=Enzyme.Forward),
AutoEnzyme(; mode=Enzyme.Reverse),
AutoFastDifferentiation(),
AutoFiniteDiff(),
AutoFiniteDifferences(; fdm=FiniteDifferences.central_fdm(3, 1)),
AutoForwardDiff(),
AutoPolyesterForwardDiff(; chunksize=1),
AutoReverseDiff(),
AutoSymbolics(),
AutoTapir(; safe_mode=false),
AutoTracker(),
AutoZygote(),
]
checkmark(x::Bool) = x ? '✅' : '❌'
unicode_check_available(backend) = checkmark(check_available(backend))
Expand All @@ -49,12 +48,11 @@ unicode_check_twoarg(backend) = checkmark(check_twoarg(backend))
io = IOBuffer()
# Table header
println(io, "| Backend | Availability | Two-argument functions | Hessian support | Example |")
println(io, "|:--------|:------------:|:----------------------:|:---------------:|:--------|")
println(io, "| Backend | Availability | Two-argument functions | Hessian support |")
println(io, "|:--------|:------------:|:----------------------:|:---------------:|")
for example in backend_examples
b = eval(Meta.parse(example)) # backend
join(io, [backend_str(b), unicode_check_available(b), unicode_check_twoarg(b), unicode_check_hessian(b), "`$example`"], '|')
for b in backend_examples
join(io, [string(b), unicode_check_available(b), unicode_check_twoarg(b), unicode_check_hessian(b)], '|')
println(io, '|' )
end
backend_table = Markdown.parse(String(take!(io)))
Expand Down
2 changes: 1 addition & 1 deletion DifferentiationInterface/docs/src/overloads.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ The cells can have three values:
```@setup overloads
using ADTypes: AbstractADType
using DifferentiationInterface
using DifferentiationInterface: backend_str, twoarg_support, TwoArgSupported
using DifferentiationInterface: twoarg_support, TwoArgSupported
using Markdown: Markdown
using Diffractor: Diffractor
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,6 @@ end

ADTypes.mode(backend::AutoDeferredEnzyme) = ADTypes.mode(AutoEnzyme(backend.mode))

DI.backend_package_name(::AutoDeferredEnzyme) = "DeferredEnzyme"

DI.nested(backend::AutoEnzyme) = AutoDeferredEnzyme(backend.mode)

const AnyAutoEnzyme{M} = Union{AutoEnzyme{M},AutoDeferredEnzyme{M}}
Expand Down
4 changes: 3 additions & 1 deletion DifferentiationInterface/src/misc/differentiate_with.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,5 +54,7 @@ Call the underlying function `dw.f` of a [`DifferentiateWith`](@ref) wrapper.

function Base.show(io::IO, dw::DifferentiateWith)
@compat (; f, backend) = dw
return print(io, "$f differentiated with $(backend_str(backend))")
return print(
io, DifferentiateWith, "(", repr(f; context=io), ",", repr(backend; context=io), ")"
)
end
10 changes: 9 additions & 1 deletion DifferentiationInterface/src/misc/sparsity_detector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,15 @@ end

function Base.show(io::IO, detector::DenseSparsityDetector{method}) where {method}
@compat (; backend, atol) = detector
return print(io, "DenseSparsityDetector{:$method}($backend; atol=$atol)")
return print(
io,
DenseSparsityDetector,
"(",
repr(backend; context=io),
"; atol=$atol, method=",
repr(method; context=io),
")",
)
end

function DenseSparsityDetector(
Expand Down
10 changes: 9 additions & 1 deletion DifferentiationInterface/src/second_order/second_order.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,15 @@ struct SecondOrder{ADO<:AbstractADType,ADI<:AbstractADType} <: AbstractADType
end

function Base.show(io::IO, backend::SecondOrder)
return print(io, "SecondOrder($(outer(backend)) / $(inner(backend)))")
return print(
io,
SecondOrder,
"(",
repr(outer(backend); context=io),
", ",
repr(inner(backend); context=io),
")",
)
end

"""
Expand Down
6 changes: 3 additions & 3 deletions DifferentiationInterface/src/utils/exceptions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@ struct MissingBackendError <: Exception
end

function Base.showerror(io::IO, e::MissingBackendError)
println(io, "failed to use $(backend_str(e.backend)) backend.")
println(io, "MissingBackendError: Failed to use $(e.backend).")
if !check_available(e.backend)
print(
io,
"""Backend package is not loaded. To fix, run
"""Backend package is probably not loaded. To fix this, try to run
import $(backend_package_name(e.backend))
import $(package_name(e.backend))
""",
)
else
Expand Down
43 changes: 7 additions & 36 deletions DifferentiationInterface/src/utils/printing.jl
Original file line number Diff line number Diff line change
@@ -1,40 +1,11 @@
backend_package_name(b::AbstractADType) = strip(string(b), ['(', ')'])
backend_package_name(b::AutoSparse) = backend_package_name(dense_ad(b))

backend_package_name(::AutoChainRules) = "ChainRules"
backend_package_name(::AutoDiffractor) = "Diffractor"
backend_package_name(::AutoEnzyme) = "Enzyme"
backend_package_name(::AutoFastDifferentiation) = "FastDifferentiation"
backend_package_name(::AutoFiniteDiff) = "FiniteDiff"
backend_package_name(::AutoFiniteDifferences) = "FiniteDifferences"
backend_package_name(::AutoForwardDiff) = "ForwardDiff"
backend_package_name(::AutoPolyesterForwardDiff) = "PolyesterForwardDiff"
backend_package_name(::AutoSymbolics) = "Symbolics"
backend_package_name(::AutoTapir) = "Tapir"
backend_package_name(::AutoTracker) = "Tracker"
backend_package_name(::AutoZygote) = "Zygote"
backend_package_name(::AutoReverseDiff) = "ReverseDiff"

backend_package_name(::AF) where {AF<:AutoForwardFromPrimitive} = string(AF)
backend_package_name(::AR) where {AR<:AutoReverseFromPrimitive} = string(AR)

function backend_str(backend::AbstractADType)
bs = backend_package_name(backend)
if mode(backend) isa ForwardMode
return "$bs (forward)"
elseif mode(backend) isa ReverseMode
return "$bs (reverse)"
elseif mode(backend) isa SymbolicMode
return "$bs (symbolic)"
elseif mode(backend) isa ForwardOrReverseMode
return "$bs (forward or reverse)"
function package_name(b::AbstractADType)
s = string(b)
k = findfirst('(', s)
if isnothing(k)
throw(ArgumentError("Cannot parse backend into package"))
else
error("Unknown mode")
return s[5:(k - 1)]
end
end

backend_str(backend::AutoSparse) = "Sparse $(backend_str(dense_ad(backend)))"

function backend_str(backend::SecondOrder)
return "$(backend_str(outer(backend))) / $(backend_str(inner(backend)))"
end
package_name(b::AutoSparse) = package_name(dense_ad(b))
18 changes: 18 additions & 0 deletions DifferentiationInterface/test/Internals/display.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
using ADTypes
using DifferentiationInterface
using Test

backend = SecondOrder(AutoForwardDiff(), AutoZygote())
@test startswith(string(backend), "SecondOrder(")
@test endswith(string(backend), ")")

detector = DenseSparsityDetector(AutoForwardDiff(); atol=1e-23)
@test startswith(string(detector), "DenseSparsityDetector(")
@test endswith(string(detector), ")")

diffwith = DifferentiateWith(exp, AutoForwardDiff())
@test startswith(string(diffwith), "DifferentiateWith(")
@test endswith(string(diffwith), ")")

@test DifferentiationInterface.package_name(AutoForwardDiff()) == "ForwardDiff"
@test DifferentiationInterface.package_name(AutoZygote()) == "Zygote"
1 change: 0 additions & 1 deletion DifferentiationInterface/test/Internals/second_order.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,5 @@ using Test
backend = SecondOrder(AutoForwardDiff(), AutoZygote())

@test ADTypes.mode(backend) isa ADTypes.ForwardMode
@test startswith(string(backend), "SecondOrder")
@test DifferentiationInterface.outer(backend) isa AutoForwardDiff
@test DifferentiationInterface.inner(backend) isa AutoZygote
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ using DataFrames: DataFrame
using DifferentiationInterface
using DifferentiationInterface:
Batch,
backend_str,
inner,
maybe_inner,
maybe_dense_ad,
Expand Down
7 changes: 3 additions & 4 deletions DifferentiationInterfaceTest/src/test_differentiation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,7 @@ function test_differentiation(
prog = ProgressUnknown(; desc="$title", spinner=true, enabled=logging)

@testset verbose = true "$title" begin
@testset verbose = detailed "$(backend_str(backend))" for (i, backend) in
enumerate(backends)
@testset verbose = detailed "$backend" for (i, backend) in enumerate(backends)
filtered_scenarios = filter(s -> compatible(backend, s), scenarios)
grouped_scenarios = group_by_operator(filtered_scenarios)
@testset verbose = detailed "$op" for (j, (op, op_group)) in
Expand All @@ -88,7 +87,7 @@ function test_differentiation(
next!(
prog;
showvalues=[
(:backend, "$(backend_str(backend)) - $i/$(length(backends))"),
(:backend, "$backend - $i/$(length(backends))"),
(:scenario_type, "$op - $j/$(length(grouped_scenarios))"),
(:scenario, "$k/$(length(op_group))"),
(:arguments, nb_args(scen)),
Expand Down Expand Up @@ -177,7 +176,7 @@ function benchmark_differentiation(
next!(
prog;
showvalues=[
(:backend, "$(backend_str(backend)) - $i/$(length(backends))"),
(:backend, "$backend - $i/$(length(backends))"),
(:scenario_type, "$op - $j/$(length(grouped_scenarios))"),
(:scenario, "$k/$(length(op_group))"),
(:arguments, nb_args(scen)),
Expand Down

0 comments on commit feb17b0

Please sign in to comment.