Uncaught LLVM-level error from neural SDE #429
Hardcoding to EnzymeVJP makes it segfault differently:

```julia
using SciMLSensitivity, Flux, LinearAlgebra
using DiffEqNoiseProcess
using StochasticDiffEq
using Statistics
using DiffEqBase.EnsembleAnalysis
using Zygote
using Optimization, OptimizationFlux
using Random

Random.seed!(238248735)

x_size = 2 # Size of the spatial dimensions in the SDE
v_size = 2 # Output size of the control

# Define the neural network for the control input
input_size = x_size + 1 # spatial dimensions plus one time dimension
nn_initial = Chain(Dense(input_size, v_size)) # The actual neural network
p_nn, model = Flux.destructure(nn_initial)
nn(x, p) = model(p)(x)

# Define the right-hand side of the SDE
const_mat = zeros(Float64, (x_size, v_size))
for i in 1:min(x_size, v_size) # fill the shared leading diagonal with ones
    const_mat[i, i] = 1
end

function f!(du, u, p, t)
    MM = nn([u; t], p)
    du .= u + const_mat * MM
end

function g!(du, u, p, t)
    du .= false * u .+ sqrt(2 * 0.001) # purely additive noise
end

# Define the SDE problem
u0 = vec(rand(Float64, (x_size, 1)))
tspan = (0.0, 1.0)
ts = collect(0:0.1:1)
prob = SDEProblem{true}(f!, g!, u0, tspan, p_nn)
W = WienerProcess(0.0, 0.0, 0.0)
probscalar = SDEProblem{true}(f!, g!, u0, tspan, p_nn, noise = W)

# Define the loss function
function loss(pars, prob, alg)
    function prob_func(prob, i, repeat)
        # Draw a fresh initial state and remake the problem
        u0tmp = vec(rand(Float64, (x_size, 1)))
        remake(prob, p = pars, u0 = u0tmp)
    end
    ensembleprob = EnsembleProblem(prob, prob_func = prob_func)
    _sol = solve(ensembleprob, alg, EnsembleThreads(),
                 sensealg = BacksolveAdjoint(autojacvec = EnzymeVJP()),
                 saveat = ts, trajectories = 10,
                 abstol = 1e-1, reltol = 1e-1)
    A = convert(Array, _sol)
    sum(abs2, A .- 1), mean(A)
end

# Actually training/fitting the model
losses = []
function callback(θ, l, pred)
    push!(losses, l)
    println("Current loss after $(length(losses)) iterations: $(losses[end])")
    false
end

optf = Optimization.OptimizationFunction((p, _) -> loss(p, probscalar, LambaEM()),
                                         Optimization.AutoZygote())
optprob = Optimization.OptimizationProblem(optf, p_nn)
res1 = Optimization.solve(optprob, ADAM(0.1), callback = callback, maxiters = 5)
```
Note that I'd expect this to (currently) fail with Enzyme; the issue is mostly that it fails without being caught.
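For anyone who just needs training to proceed while this is open, a possible workaround, in the same "avoid Enzyme" spirit as the SciMLSensitivity commit referenced in this thread, is to select a different VJP backend for the adjoint so the Enzyme code path is never reached. `ZygoteVJP` and `ReverseDiffVJP` are the standard alternatives exported by SciMLSensitivity; this is a sketch, not a fix for the underlying Enzyme crash:

```julia
# Sketch of a workaround: keep BacksolveAdjoint, but have Zygote compute the
# vector-Jacobian products instead of Enzyme, bypassing the crashing code path.
sensealg = BacksolveAdjoint(autojacvec = ZygoteVJP())
_sol = solve(ensembleprob, alg, EnsembleThreads(), sensealg = sensealg,
             saveat = ts, trajectories = 10, abstol = 1e-1, reltol = 1e-1)
```

`ReverseDiffVJP(true)` (compiled tapes) may be faster for in-place `f!` with small state, at the cost of not supporting branching on `u`.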
To confirm: have you tried this on the latest main?
ChrisRackauckas added a commit to SciML/SciMLSensitivity.jl that referenced this issue on Aug 26, 2022: "Just avoid EnzymeAD/Enzyme.jl#429"
Reducing this one is hard since it just crashes everything.