Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ForwardDiff Batch Mode Support #29

Merged
merged 5 commits into from
Sep 3, 2020
Merged

ForwardDiff Batch Mode Support #29

merged 5 commits into from
Sep 3, 2020

Conversation

agerlach
Copy link
Collaborator

@agerlach agerlach commented Jul 10, 2020

Corrects the output size for Zygote Batch mode Adjoint.

Comment on lines 414 to 418
dfdp = function (dx,x,p)
# dfdp = function (dx,x,p)
dfdp = function (x,p)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed to oop b/c unsure how to copy result to dx w/o mutating. If using Buffer we need to allocate for the result anyway

dx = Zygote.Buffer(x)
prob.f(dx,x,p)
copy(dx)
_dx = Zygote.Buffer(x, prob.nout, size(x,2))
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For some of the quadrature methods prob.batch isn't adhered to. It looks like it serves more as a max batch number. Some methods "grow" the batch size. So, need to set the solution size accordingly.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Further more, Quadrature.jl runs a test calling prob.f([lb ub], p) to test if the solution of the integrand is a Vector. So if batch>0, the code will always try two point batch first.

@agerlach agerlach changed the title adjoint updates towards working batch mode. Zygote adjoint updates towards working batch mode. Jul 10, 2020
@ChrisRackauckas ChrisRackauckas changed the title Zygote adjoint updates towards working batch mode. [WIP] Zygote adjoint updates towards working batch mode. Jul 10, 2020
@agerlach
Copy link
Collaborator Author

ERROR: DimensionMismatch("tried to assign 1 elements to 2 destinations")
Stacktrace:
 [1] throw_setindex_mismatch(::Array{Float64,1}, ::Tuple{Int64}) at ./indices.jl:191
 [2] setindex_shape_check at ./indices.jl:242 [inlined]
 [3] setindex!(::Array{Float64,2}, ::Array{Float64,1}, ::Colon) at ./array.jl:860
 [4] (::Zygote.var"#1072#1073"{Zygote.Context,Zygote.Buffer{Float64,Array{Float64,2}}})(::Array{Float64,1}) at /Users/gerlacar/.julia/packages/Zygote/1GXzF/src/lib/buffer.jl:49
 [5] #2439#back at /Users/gerlacar/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49 [inlined]
 [6] #959 at /Users/gerlacar/.julia/dev/Quadrature/src/Quadrature.jl:422 [inlined]
 [7] (::typeof(∂(λ)))(::Array{Float64,1}) at /Users/gerlacar/.julia/packages/Zygote/1GXzF/src/compiler/interface2.jl:0
 [8] #37 at /Users/gerlacar/.julia/packages/Zygote/1GXzF/src/compiler/interface.jl:45 [inlined]
 [9] (::Quadrature.var"#958#966"{Array{Float64,1},QuadratureProblem{true,Array{Float64,1},DiffEqUncertainty.var"#852#861"{DiffEqUncertainty.var"#859#868",DiffEqUncertainty.var"#860#869",Int64,Base.Iterators.Pairs{Symbol,Float64,Tuple{Symbol},NamedTuple{(:saveat,),Tuple{Float64}}},typeof(cost),ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},ODEFunction{true,typeof(fiip),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},var"#u0_f#212",var"#p_f#213",Tuple{Tsit5,EnsembleThreads},Int64,Array{Bool,1},Array{Bool,1}},Array{Float64,1},Array{Float64,1},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}}})(::Array{Float64,2}, ::Array{Float64,1}) at /Users/gerlacar/.julia/dev/Quadrature/src/Quadrature.jl:424
 [10] __solvebp_call(::QuadratureProblem{false,Array{Float64,1},Quadrature.var"#958#966"{Array{Float64,1},QuadratureProblem{true,Array{Float64,1},DiffEqUncertainty.var"#852#861"{DiffEqUncertainty.var"#859#868",DiffEqUncertainty.var"#860#869",Int64,Base.Iterators.Pairs{Symbol,Float64,Tuple{Symbol},NamedTuple{(:saveat,),Tuple{Float64}}},typeof(cost),ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},ODEFunction{true,typeof(fiip),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},var"#u0_f#212",var"#p_f#213",Tuple{Tsit5,EnsembleThreads},Int64,Array{Bool,1},Array{Bool,1}},Array{Float64,1},Array{Float64,1},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}}},Array{Float64,1},Array{Float64,1},Base.Iterators.Pairs{Symbol,Real,Tuple{Symbol,Symbol,Symbol},NamedTuple{(:reltol, :abstol, :maxiters),Tuple{Float64,Float64,Int64}}}}, ::CubaCuhre, ::Quadrature.ReCallVJP{Quadrature.ZygoteVJP}, ::Array{Float64,1}, ::Array{Float64,1}, ::Array{Float64,1}; reltol::Float64, abstol::Float64, maxiters::Int64, kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /Users/gerlacar/.julia/dev/Quadrature/src/Quadrature.jl:346
 [11] (::Quadrature.var"#quadrature_adjoint#965"{Base.Iterators.Pairs{Symbol,Real,Tuple{Symbol,Symbol,Symbol},NamedTuple{(:reltol, :abstol, :maxiters),Tuple{Float64,Float64,Int64}}},QuadratureProblem{true,Array{Float64,1},DiffEqUncertainty.var"#852#861"{DiffEqUncertainty.var"#859#868",DiffEqUncertainty.var"#860#869",Int64,Base.Iterators.Pairs{Symbol,Float64,Tuple{Symbol},NamedTuple{(:saveat,),Tuple{Float64}}},typeof(cost),ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},ODEFunction{true,typeof(fiip),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},var"#u0_f#212",var"#p_f#213",Tuple{Tsit5,EnsembleThreads},Int64,Array{Bool,1},Array{Bool,1}},Array{Float64,1},Array{Float64,1},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}},CubaCuhre,Quadrature.ReCallVJP{Quadrature.ZygoteVJP},Array{Float64,1},Array{Float64,1},Array{Float64,1},Tuple{}})(::Array{Float64,1}) at /Users/gerlacar/.julia/dev/Quadrature/src/Quadrature.jl:446
 [12] #3733#back at /Users/gerlacar/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:55 [inlined]
 [13] #175 at /Users/gerlacar/.julia/packages/Zygote/1GXzF/src/lib/lib.jl:182 [inlined]
 [14] (::Zygote.var"#347#back#177"{Zygote.var"#175#176"{Quadrature.var"#3733#back#974"{Quadrature.var"#quadrature_adjoint#965"{Base.Iterators.Pairs{Symbol,Real,Tuple{Symbol,Symbol,Symbol},NamedTuple{(:reltol, :abstol, :maxiters),Tuple{Float64,Float64,Int64}}},QuadratureProblem{true,Array{Float64,1},DiffEqUncertainty.var"#852#861"{DiffEqUncertainty.var"#859#868",DiffEqUncertainty.var"#860#869",Int64,Base.Iterators.Pairs{Symbol,Float64,Tuple{Symbol},NamedTuple{(:saveat,),Tuple{Float64}}},typeof(cost),ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},ODEFunction{true,typeof(fiip),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},var"#u0_f#212",var"#p_f#213",Tuple{Tsit5,EnsembleThreads},Int64,Array{Bool,1},Array{Bool,1}},Array{Float64,1},Array{Float64,1},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}},CubaCuhre,Quadrature.ReCallVJP{Quadrature.ZygoteVJP},Array{Float64,1},Array{Float64,1},Array{Float64,1},Tuple{}}},Tuple{NTuple{8,Nothing},Tuple{}}}})(::Array{Float64,1}) at /Users/gerlacar/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
 [15] #solve#3 at /Users/gerlacar/.julia/dev/Quadrature/src/Quadrature.jl:60 [inlined]
 [16] (::typeof(∂(#solve#3)))(::Array{Float64,1}) at /Users/gerlacar/.julia/packages/Zygote/1GXzF/src/compiler/interface2.jl:0
 [17] #175 at /Users/gerlacar/.julia/packages/Zygote/1GXzF/src/lib/lib.jl:182 [inlined]
 [18] (::Zygote.var"#347#back#177"{Zygote.var"#175#176"{typeof(∂(#solve#3)),Tuple{NTuple{5,Nothing},Tuple{}}}})(::Array{Float64,1}) at /Users/gerlacar/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
 [19] (::typeof(∂(solve##kw)))(::Array{Float64,1}) at /Users/gerlacar/.julia/packages/Zygote/1GXzF/src/compiler/interface2.jl:0
 [20] #expectation#851 at /Users/gerlacar/.julia/dev/DiffEqUncertainty/src/koopman.jl:115 [inlined]
 [21] (::typeof(∂(#expectation#851)))(::Array{Float64,1}) at /Users/gerlacar/.julia/packages/Zygote/1GXzF/src/compiler/interface2.jl:0
 [22] (::Zygote.var"#175#176"{typeof(∂(#expectation#851)),Tuple{NTuple{16,Nothing},Tuple{Nothing,Nothing}}})(::Array{Float64,1}) at /Users/gerlacar/.julia/packages/Zygote/1GXzF/src/lib/lib.jl:182
 [23] (::Zygote.var"#347#back#177"{Zygote.var"#175#176"{typeof(∂(#expectation#851)),Tuple{NTuple{16,Nothing},Tuple{Nothing,Nothing}}}})(::Array{Float64,1}) at /Users/gerlacar/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
 [24] (::typeof(∂(expectation##kw)))(::Array{Float64,1}) at /Users/gerlacar/.julia/packages/Zygote/1GXzF/src/compiler/interface2.jl:0
 [25] (::Zygote.var"#175#176"{typeof(∂(expectation##kw)),Tuple{NTuple{9,Nothing},Tuple{Nothing}}})(::Array{Float64,1}) at /Users/gerlacar/.julia/packages/Zygote/1GXzF/src/lib/lib.jl:182
 [26] #347#back at /Users/gerlacar/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49 [inlined]
 [27] #loss_koop3#211 at /Users/gerlacar/.julia/dev/DiffEqUncertainty/test/expectationAD.jl:63 [inlined]
 [28] (::typeof(∂(#loss_koop3#211)))(::Float64) at /Users/gerlacar/.julia/packages/Zygote/1GXzF/src/compiler/interface2.jl:0
 [29] #175 at /Users/gerlacar/.julia/packages/Zygote/1GXzF/src/lib/lib.jl:182 [inlined]
 [30] #347#back at /Users/gerlacar/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49 [inlined]
 [31] (::typeof(∂(loss_koop3##kw)))(::Float64) at /Users/gerlacar/.julia/packages/Zygote/1GXzF/src/compiler/interface2.jl:0
 [32] #276 at ./REPL[36]:1 [inlined]
 [33] (::typeof(∂(#276)))(::Float64) at /Users/gerlacar/.julia/packages/Zygote/1GXzF/src/compiler/interface2.jl:0
 [34] (::Zygote.var"#37#38"{typeof(∂(#276))})(::Float64) at /Users/gerlacar/.julia/packages/Zygote/1GXzF/src/compiler/interface.jl:45
 [35] gradient(::Function, ::Array{Float64,1}) at /Users/gerlacar/.julia/packages/Zygote/1GXzF/src/compiler/interface.jl:54
 [36] top-level scope at util.jl:175

@agerlach
Copy link
Collaborator Author

agerlach commented Sep 1, 2020

It looks like batch mode is now working for ForwardDiff, but Zygote still has issues. For R->R it is working, but for R^n->R, the first element of the gradient is the wrong value. Interestingly it is equal to the true solution * batchsize. Here, the actual batch size is 17. True solution is 8. Zygote is reporting 136, but 136/17=8.

using Quadrature, Cuba, Cubature, Zygote, FiniteDiff, ForwardDiff, Test
### Batch Single dim
f(x,p) = x*p[1].+p[2]*p[3]

lb =1.0
ub = 3.0
p = [2.0, 3.0, 4.0]
prob = QuadratureProblem(f,lb,ub,p)

function testf3(lb,ub,p; f=f)
    prob = QuadratureProblem(f,lb,ub,p, batch = 10, nout=1)
    solve(prob, CubatureJLh(); reltol=1e-3,abstol=1e-3)[1]
end

dp1 = ForwardDiff.gradient(p->testf3(lb,ub,p),p)
dp2 = Zygote.gradient(p->testf3(lb,ub,p),p)[1]
dp3 = FiniteDiff.finite_difference_gradient(p->testf3(lb,ub,p),p)

@test dp1  dp3 #passes
@test dp2  dp3 #passes

### Batch multi dim
f(x,p) = x[1,:]*p[1].+p[2]*p[3]

lb =[1.0,1.0]
ub = [3.0,3.0]
p = [2.0, 3.0, 4.0]
prob = QuadratureProblem(f,lb,ub,p)

function testf3(lb,ub,p; f=f)
    prob = QuadratureProblem(f,lb,ub,p, batch = 10, nout=1)
    solve(prob, CubatureJLh(); reltol=1e-3,abstol=1e-3)[1]
end

dp1 = ForwardDiff.gradient(p->testf3(lb,ub,p),p)
dp2 = Zygote.gradient(p->testf3(lb,ub,p),p)[1]
dp3 = FiniteDiff.finite_difference_gradient(p->testf3(lb,ub,p),p)

@test dp1  dp3 # passes
@test dp2  dp3 # Fail  [136.0,16.0,12.0] ≈ [8.0,16.0,12.0]

@ChrisRackauckas
Copy link
Member

change the title and lets merge this at least for now. That is... an odd Zygote behavior for sure haha. Probably some accidental referencing instead of copying.

@agerlach agerlach changed the title [WIP] Zygote adjoint updates towards working batch mode. ForwardDiff Batch Mode Support Sep 1, 2020
@agerlach
Copy link
Collaborator Author

agerlach commented Sep 1, 2020

Hold off on merging. I am going to push an additional test first

@agerlach
Copy link
Collaborator Author

agerlach commented Sep 1, 2020

OK, it is ready. I added the single and multi dim tests from above and added @test_broken for multi-dim Zygote

@agerlach
Copy link
Collaborator Author

agerlach commented Sep 1, 2020

I previously missed fix for iip. Should be corrected w/ test now. This now allows batch mode AD with DEU.

@ChrisRackauckas ChrisRackauckas merged commit 11feb79 into SciML:master Sep 3, 2020
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants