Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add Trapezoidal rule #173

Merged
merged 18 commits into from
Sep 21, 2023
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 14 additions & 14 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,29 +13,36 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
IntegralsFastGaussQuadratureExt = "FastGaussQuadrature"
IntegralsForwardDiffExt = "ForwardDiff"
IntegralsZygoteExt = ["Zygote", "ChainRulesCore"]

[compat]
ChainRulesCore = "0.10.7, 1"
CommonSolve = "0.2"
Distributions = "0.23, 0.24, 0.25"
FastGaussQuadrature = "0.5"
ForwardDiff = "0.10"
HCubature = "1.4"
MonteCarloIntegration = "0.0.1, 0.0.2, 0.0.3"
QuadGK = "2.5"
Reexport = "0.2, 1.0"
Requires = "1"
SciMLBase = "1.70"
SciMLBase = "1.98"
Zygote = "0.4.22, 0.5, 0.6"
julia = "1.6"
FastGaussQuadrature = "0.5"

[extensions]
IntegralsForwardDiffExt = "ForwardDiff"
IntegralsZygoteExt = ["Zygote", "ChainRulesCore"]
IntegralsFastGaussQuadratureExt = "FastGaussQuadrature"

[extras]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Expand All @@ -44,13 +51,6 @@ SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838"

[targets]
test = ["SciMLSensitivity", "StaticArrays", "FiniteDiff", "Pkg", "SafeTestsets", "Test", "Distributions", "ForwardDiff", "Zygote", "ChainRulesCore", "FastGaussQuadrature"]

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838"
1 change: 1 addition & 0 deletions docs/pages.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ pages = ["index.md",
"Tutorials" => Any["tutorials/numerical_integrals.md",
"tutorials/differentiating_integrals.md"],
"Basics" => Any["basics/IntegralProblem.md",
"basics/SampledIntegralProblem.md",
"basics/solve.md",
"basics/FAQ.md"],
"Solvers" => Any["solvers/IntegralSolvers.md"],
Expand Down
73 changes: 73 additions & 0 deletions docs/src/basics/SampledIntegralProblem.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Integrating pre-sampled data

In some cases, instead of a function that acts as integrand,
one only possesses a list of data points `y` at a set of sampling
locations `x`, that must be integrated. This package contains functionality
for doing that.

## Example

Say, by some means we have generated a dataset `x` and `y`:
```example 1
using Integrals # hide
f = x -> x^2
x = range(0, 1, length=20)
y = f.(x)
```

Now, we can integrate this data set as follows:

```example 1
problem = SampledIntegralProblem(y, x)
method = TrapezoidalRule()
solve(problem, method)
```

The exact aswer is of course \$ 1/3 \$.

## Details

### Non-equidistant grids

If the sampling points `x` are provided as an `AbstractRange`
(constructed with the `range` function for example), faster methods are used that take advantage of
the fact that the points are equidistantly spaced. Otherwise, general methods are used for
non-uniform grids.

Example:

```example 2
using Integrals # hide
f = x -> x^7
x = [0.0; sort(rand(1000)); 1.0]
y = f.(x)
problem = SampledIntegralProblem(y, x)
method = TrapezoidalRule()
solve(problem, method)
```

### Evaluating multiple integrals at once

If the provided data set `y` is a multidimensional array, the integrals are evaluated across only one
of its axes. For performance reasons, the last axis of the array `y` is chosen by default, but this can be modified with the `dim`
keyword argument to the problem definition.

```example 3
using Integrals # hide
f1 = x -> x^2
f2 = x -> x^3
f3 = x -> x^4
x = range(0, 1, length=20)
y = [f1.(x) f2.(x) f3.(x)]
problem = SampledIntegralProblem(y, x; dim=1)
method = TrapezoidalRule()
solve(problem, method)
```

### Supported methods

Right now, only the `TrapezoidalRule` is supported, [see wikipedia](https://en.wikipedia.org/wiki/Trapezoidal_rule).

```@docs
TrapezoidalRule
```
47 changes: 47 additions & 0 deletions docs/src/tutorials/caching_interface.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,50 @@ Note that the types of these variables is not allowed to change.
If it is necessary to change the integrand `f` instead of defining a new
`IntegralProblem`, consider using
[FunctionWrappers.jl](https://github.com/yuyichao/FunctionWrappers.jl).

## Caching for sampled integral problems

For sampled integral problems, it is possible to cache the weights and reuse
them for multiple data sets.
```@example cache2
using Integrals

x = 0.0:0.1:1.0
y = sin.(x)

prob = SampledIntegralProblem(y, x)
alg = TrapezoidalRule()

cache = init(prob, alg)
sol1 = solve!(cache)
```

```@example cache2
cache.y = cos.(x) # use .= to update in-place
sol2 = solve!(cache)
```
If the grid is modified, the weights are recomputed.
```@example cache2
cache.x = 0.0:0.2:2.0
cache.y = sin.(cache.x)
sol3 = solve!(cache)
```

For multi-dimensional datasets, the integration dimension can also be changed
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh that's pretty cool.

```@example cache3
using Integrals

x = 0.0:0.1:1.0
y = sin.(x) .* cos.(x')

prob = SampledIntegralProblem(y, x)
alg = TrapezoidalRule()

cache = init(prob, alg)
sol1 = solve!(cache)
```

```@example cache3
cache.dim = 1
sol2 = solve!(cache)
```
4 changes: 3 additions & 1 deletion src/Integrals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ include("init.jl")
include("algorithms.jl")
include("infinity_handling.jl")
include("quadrules.jl")
include("sampled.jl")
include("trapezoidal.jl")

abstract type QuadSensitivityAlg end
struct ReCallVJP{V}
Expand Down Expand Up @@ -148,5 +150,5 @@ function __solvebp_call(prob::IntegralProblem, alg::VEGAS, sensealg, lb, ub, p;
SciMLBase.build_solution(prob, alg, val, err, chi = chi, retcode = ReturnCode.Success)
end

export QuadGKJL, HCubatureJL, VEGAS, GaussLegendre, QuadratureRule
export QuadGKJL, HCubatureJL, VEGAS, GaussLegendre, QuadratureRule, TrapezoidalRule
end # module
23 changes: 22 additions & 1 deletion src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,28 @@ function GaussLegendre(; n = 250, subintervals = 1, nodes = nothing, weights = n
end

"""
QuadratureRule(q; n=250)
TrapezoidalRule

Struct for evaluating an integral via the trapezoidal rule.


Example with sampled data:

```
using Integrals
f = x -> x^2
x = range(0, 1, length=20)
y = f.(x)
problem = SampledIntegralProblem(y, x)
method = TrapezoidalRul()
solve(problem, method)
```
"""
struct TrapezoidalRule <: SciMLBase.AbstractIntegralAlgorithm
end

"""
QuadratureRule(q; n=250)

Algorithm to construct and evaluate a quadrature rule `q` of `n` points computed from the
inputs as `x, w = q(n)`. It assumes the nodes and weights are for the standard interval
Expand Down
62 changes: 62 additions & 0 deletions src/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,65 @@ end
function __solvebp_call(cache::IntegralCache, args...; kwargs...)
__solvebp_call(build_problem(cache), args...; kwargs...)
end


mutable struct SampledIntegralCache{Y, X, D, PK, A, K, Tc}
y::Y
x::X
dim::D
prob_kwargs::PK
alg::A
kwargs::K
isfresh::Bool # state of whether weights have been calculated
cacheval::Tc # store alg weights here
end

function Base.setproperty!(cache::SampledIntegralCache, name::Symbol, x)
if name === :x
setfield!(cache, :isfresh, true)
end
setfield!(cache, name, x)
end

function SciMLBase.init(prob::SampledIntegralProblem,
alg::SciMLBase.AbstractIntegralAlgorithm;
kwargs...)
NamedTuple(kwargs) == NamedTuple() || throw(ArgumentError("There are no keyword arguments allowed to `solve`"))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that's odd, why not? There are many keyword arguments that would go here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@lxvm is there a specific reason why not, or did you just not expect it to ever be necessary?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The current implementation of the solver doesn't use any keyword arguments, so I didn't want an api with keywords. If future algorithms for sampled integral problems need them I would expect this to change, but there are no convergence criteria for this kind of problem, so I wasn't expecting any

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh just the sampled data methods. Okay, yeah for now might as well throw. We always try to throw if keyword arguments that are incorrect so this is good.


cacheval = init_cacheval(alg, prob)
isfresh = true

SampledIntegralCache(
prob.y,
prob.x,
prob.dim,
prob.kwargs,
alg,
kwargs,
isfresh,
cacheval)
end


"""
```julia
solve(prob::SampledIntegralProblem, alg::SciMLBase.AbstractIntegralAlgorithm; kwargs...)
```

## Keyword Arguments

There are no keyword arguments used to solve `SampledIntegralProblem`s
"""
function SciMLBase.solve(prob::SampledIntegralProblem,
alg::SciMLBase.AbstractIntegralAlgorithm;
kwargs...)
solve!(init(prob, alg; kwargs...))
end

function SciMLBase.solve!(cache::SampledIntegralCache)
__solvebp(cache, cache.alg; cache.kwargs...)
end

function build_problem(cache::SampledIntegralCache)
SampledIntegralProblem(cache.y, cache.x; dim = dimension(cache.dim), cache.prob_kwargs...)
end
71 changes: 71 additions & 0 deletions src/sampled.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
abstract type AbstractWeights end

# must have field `n` for length, and a field `h` for stepsize
abstract type UniformWeights <: AbstractWeights end
@inline Base.iterate(w::UniformWeights) = (0 == w.n) ? nothing : (w[1], 1)
@inline Base.iterate(w::UniformWeights, i) = (i == w.n) ? nothing : (w[i+1], i+1)
Base.length(w::UniformWeights) = w.n
Base.eltype(w::UniformWeights) = typeof(w.h)
Base.size(w::UniformWeights) = (length(w), )

Check warning on line 9 in src/sampled.jl

View check run for this annotation

Codecov / codecov/patch

src/sampled.jl#L7-L9

Added lines #L7 - L9 were not covered by tests

# must contain field `x` which are the sampling points
abstract type NonuniformWeights <: AbstractWeights end
@inline Base.iterate(w::NonuniformWeights) = (0 == length(w.x)) ? nothing : (w[firstindex(w.x)], firstindex(w.x))
@inline Base.iterate(w::NonuniformWeights, i) = (i == lastindex(w.x)) ? nothing : (w[i+1], i+1)
Base.length(w::NonuniformWeights) = length(w.x)
Base.eltype(w::NonuniformWeights) = eltype(w.x)
Base.size(w::NonuniformWeights) = (length(w), )

Check warning on line 17 in src/sampled.jl

View check run for this annotation

Codecov / codecov/patch

src/sampled.jl#L15-L17

Added lines #L15 - L17 were not covered by tests

_eachslice(data::AbstractArray; dims=ndims(data)) = eachslice(data; dims=dims)
_eachslice(data::AbstractArray{T, 1}; dims=ndims(data)) where T = data


# these can be removed when the Val(dim) is removed from SciMLBase
dimension(::Val{D}) where {D} = D

Check warning on line 24 in src/sampled.jl

View check run for this annotation

Codecov / codecov/patch

src/sampled.jl#L24

Added line #L24 was not covered by tests
dimension(D::Int) = D


function evalrule(data::AbstractArray, weights, dim)
fw = zip(_eachslice(data, dims=dim), weights)
next = iterate(fw)
next === nothing && throw(ArgumentError("No points to integrate"))
(f1, w1), state = next
out = w1 * f1
next = iterate(fw, state)
if isbits(out)
while next !== nothing
(fi, wi), state = next
out += wi * fi
next = iterate(fw, state)
end
else
while next !== nothing
(fi, wi), state = next
out .+= wi .* fi
next = iterate(fw, state)
end
end
return out
end


# can be reused for other sampled rules, which should implement find_weights(x, alg)

function init_cacheval(alg::SciMLBase.AbstractIntegralAlgorithm, prob::SampledIntegralProblem)
find_weights(prob.x, alg)
end

function __solvebp_call(cache::SampledIntegralCache, alg::SciMLBase.AbstractIntegralAlgorithm; kwargs...)
dim = dimension(cache.dim)
err = nothing
data = cache.y
grid = cache.x
if cache.isfresh
cache.cacheval = find_weights(grid, alg)
cache.isfresh = false
end
weights = cache.cacheval
I = evalrule(data, weights, dim)
prob = build_problem(cache)
return SciMLBase.build_solution(prob, alg, I, err, retcode = ReturnCode.Success)
end
Loading
Loading