Bidirectional RNN #708
base: main
Conversation
Codecov Report

```
@@            Coverage Diff             @@
##             main     #708        +/-  ##
===========================================
- Coverage   96.40%   76.81%    -19.60%
===========================================
  Files          54       54
  Lines        2726     2730         +4
===========================================
- Hits         2628     2097       -531
- Misses         98      633       +535
```

View full report in Codecov by Sentry.
test/layers/recurrent_tests.jl (Outdated)

```julia
if mode != "AMDGPU"
    # gradients test failed after vcat
    # __f = p -> sum(Base.Fix1(sum, abs2), first(bi_rnn(x, p, st)))
    # @eval @test_gradients $__f $ps atol=1e-2 rtol=1e-2 gpu_testing=$ongpu
```
The gradient test fails for the matrix produced by `vcat`. I'm not sure whether the gradient-test code is written incorrectly or my `Bidirectional` implementation is wrong.
Which backend was failing? The basic pipeline is: compute the gradients with Zygote, then compute them with FiniteDifferences, ReverseDiff, Tracker, etc., and compare the results against Zygote.
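For illustration, a minimal sketch of that comparison pipeline (the `loss` closure and the tolerance are illustrative, not the repository's actual test code):

```julia
using Zygote, FiniteDifferences

# Scalar loss over the model output, closed over the inputs `x` and state `st`.
loss(p) = sum(abs2, first(bi_rnn(x, p, st)))

gs_zygote = Zygote.gradient(loss, ps)[1]                        # reverse-mode AD
gs_fd = FiniteDifferences.grad(central_fdm(5, 1), loss, ps)[1]  # finite-difference reference

# The two gradients should agree within a loose tolerance (e.g. atol = rtol = 1e-2).
```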
@avik-pal Hi! I have completed the implementation of the bidirectional RNN.
😭😭😭 Any suggestions? I will fix it today.
src/layers/recurrent.jl (Outdated)

```
## Parameters

  - Same as `cell` and `backward_cell`.
```
It should be a `NamedTuple` with `cell` and `backward_cell`.
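A hedged sketch of how that docstring section could read (the wording is my assumption, not the PR's final text):

```
## Parameters

  - `cell`: parameters of the forward `cell`.
  - `backward_cell`: parameters of the backward cell.
```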
src/layers/recurrent.jl (Outdated)

```julia
        backward_cell::Union{AbstractRecurrentCell, Nothing}=nothing,
        merge_mode::Union{Function, Nothing}=vcat,
        ordering::AbstractTimeSeriesDataBatchOrdering=BatchLastIndex())
    if !isnothing(backward_cell) && cell.in_dims != backward_cell.in_dims
```
Check via `backward_cell !== nothing`; there used to be a performance consideration for not using `isnothing` (not sure if that still holds for current Julia versions, but better to be safe).
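For context, a minimal sketch of the suggested spelling (the `ArgumentError` body is a guess at what follows the truncated snippet; `isnothing(x)` is defined as `x === nothing`, and the `!==` form historically inlined more reliably):

```julia
# `!==` is an egal comparison the compiler can fold; older Julia versions
# did not always inline `isnothing`.
if backward_cell !== nothing && cell.in_dims != backward_cell.in_dims
    throw(ArgumentError("`cell` and `backward_cell` must have the same input dimension"))
end
```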
src/layers/recurrent.jl (Outdated)

```julia
    backward_rnn_layer = isnothing(backward_cell) ? deepcopy(layer) :
        Recurrence(backward_cell; return_sequence=true, ordering)
    fuse_op = isnothing(merge_mode) ? nothing : Broadcast.BroadcastFunction(merge_mode)
    return Parallel(
```
I was thinking maybe we create a struct:

```julia
struct BidirectionalRNN <: ...
    model
end

(rnn::BidirectionalRNN)(x, ps, st) = rnn.model(x, ps, st)
```

This is mostly to be consistent with other wrapper layers in Lux, which return the same type rather than a `Parallel` (as in this case).
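A hedged, runnable version of that sketch (the supertype name and the type parameter are my assumptions based on how other Lux container layers of that era were written, not the reviewer's exact intent):

```julia
using Lux

# Parameterizing on the concrete Parallel type keeps the field concretely
# typed, so calls through it do not hit runtime dispatch.
struct BidirectionalRNN{M <: Parallel} <: Lux.AbstractExplicitContainerLayer{(:model,)}
    model::M
end

(rnn::BidirectionalRNN)(x, ps, st::NamedTuple) = rnn.model(x, ps, st)
```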
test/layers/recurrent_tests.jl (Outdated)
```julia
y_, st__ = bi_rnn_no_merge(x, ps, st)

@jet bi_rnn(x, ps, st)
@jet bi_rnn_no_merge(x, ps, st)
```
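For context, `LuxTestUtils.@jet` appears to wrap JET's test macros; roughly equivalent direct calls might look like this (an assumption about its expansion, not taken from the thread):

```julia
using JET, Test

@test_call target_modules=(Lux, LuxLib, LuxCore) bi_rnn(x, ps, st)
@test_opt  target_modules=(Lux, LuxLib, LuxCore) bi_rnn(x, ps, st)
```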
@avik-pal Hi, I have updated the implementation, but the `@jet` test fails and I can't find the reason. Could you please take a look when you have time? 😭😭😭
Yeah, the tests won't really tell you much here. Install JET.jl and run `@report_call bi_rnn_no_merge(x, ps, st) target_modules=[Lux, LuxLib, LuxCore]` and `@report_opt bi_rnn_no_merge(x, ps, st) target_modules=[Lux, LuxLib, LuxCore]`. This will give you a list of the function calls that cause the test to fail.
@avik-pal Hi, I tried `@report_call bi_rnn_no_merge(x, ps, st) target_modules=[Lux, LuxLib, LuxCore]`, but it reports `@report_call expects only one non-keyword argument`, so I removed `target_modules=[Lux, LuxLib, LuxCore]`. I also replaced `@jet` with `@report_call` and `@report_opt`; all tests passed and no errors were thrown. Is this a `LuxTestUtils` bug? Could you please check this branch locally? Thank you very much!!
No, don't replace `@jet` with the `report_*` functions; you need to run those two in the REPL / VSCode. That will point to the correct line. Note the keyword goes before the call: `@report_call target_modules=[Lux, LuxLib, LuxCore] bi_rnn_no_merge(x, ps, st)`. See https://aviatesk.github.io/JET.jl/dev/tutorial/#Analyse-methods-with-@report_call
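A minimal REPL session following that advice (the setup lines are assumptions about the test's local variables):

```julia
using JET, Lux, LuxLib, LuxCore

# Assumes `bi_rnn_no_merge`, `x`, `ps`, and `st` are defined as in the tests.
@report_call target_modules=[Lux, LuxLib, LuxCore] bi_rnn_no_merge(x, ps, st)
@report_opt  target_modules=[Lux, LuxLib, LuxCore] bi_rnn_no_merge(x, ps, st)
```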
@avik-pal 🥹🥹 Hi, could you please help me review the PR and find the reason why the `@jet` test fails? Thank you very much.
```julia
    model::Parallel
end

(rnn::BidirectionalRNN)(x, ps, st::NamedTuple) = rnn.model(x, ps, st)
```
@avik-pal Hi! JET.jl's `@report_opt bi_rnn(x, ps, st)` reports:

```
(::BidirectionalRNN)(x::Array{…}, ps::@NamedTuple{…}, st::@NamedTuple{…}) @ Lux ./Lux.jl/src/layers/recurrent.jl:690
│ runtime dispatch detected: %1::Parallel(x::Array{Float32, 3}, ps::@NamedTuple{layer_1::@NamedTuple{…}, layer_2::@NamedTuple{…}}, st::@NamedTuple{layer_1::@NamedTuple{…}, layer_2::@NamedTuple{…}})::Any
```

I have no idea why it says runtime dispatch here... `Chain` has the same function signature: `(c::Chain)(x, ps, st::NamedTuple) = applychain(c.layers, x, ps, st)`.
It should have more information in the report
🥲 I'm sorry, the report has only the two lines I posted above.

@avik-pal Sorry to bother you. I still don't know how to fix `(rnn::BidirectionalRNN)(x, ps, st::NamedTuple) = rnn.model(x, ps, st)`:

```
julia> @report_opt bi_rnn(x, ps, st)
═════ 1 possible error found ═════
┌ (::BidirectionalRNN)(x::Array{…}, ps::@NamedTuple{…}, st::@NamedTuple{…}) @ Lux ./Lux.jl/src/layers/recurrent.jl:690
│ runtime dispatch detected: %1::Parallel(x::Array{Float32, 3}, ps::@NamedTuple{layer_1::@NamedTuple{…}, layer_2::@NamedTuple{…}}, st::@NamedTuple{layer_1::@NamedTuple{…}, layer_2::@NamedTuple{…}})::Any
└────────────────────
```
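One plausible explanation (my assumption; the thread never states the root cause): `Parallel` is a parametric type, so a field annotated `model::Parallel` is abstractly typed, and the call `rnn.model(x, ps, st)` must dispatch at runtime. `Chain` avoids this because its layers live in a concretely-typed field. A self-contained illustration of the general pattern:

```julia
struct Slow
    r::Ref   # abstract field type → field access infers as Any
end

struct Fast{R <: Ref}
    r::R     # concrete via the type parameter
end

f(s::Slow) = s.r[] + 1   # `+` is dispatched at runtime
g(s::Fast) = s.r[] + 1   # statically dispatched

# @code_warntype f(Slow(Ref(1)))  # shows ::Any from the field access
# @code_warntype g(Fast(Ref(1)))  # fully inferred
```

Parameterizing the wrapper (e.g. `model::M` with `M <: Parallel`) would make the field concrete and should silence the JET report.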
I will take a look on the weekend.
😥 Hi, could you help me review this PR?
Force-pushed from b97e29a to e7ae725.
Fix the gradient tests and it should be fine. The failures are probably originating from lazy `reverse` rrules for Zygote not being defined for GPU arrays.
Thank you for your help! Some gradient tests still fail here, and I have no idea how to fix them:

```
19:22:03 | maxrss 20.0% | mem 67.2% | DONE (1/1) test item "Bidirectional" 112.9 secs (68.5% compile, 0.1% recompile, 6.2% GC), 188.79 M allocs (13.948 GB)
Test Summary:                        | Pass  Error  Total     Time
ReTestItem Tests                     |  120     12    132  2m08.7s
  Bidirectional                      |   60      6     66  2m03.7s
    cpu                              |   30      3     33  1m08.8s
      cell: RNNCell                  |   10      1     11    52.2s
      cell: LSTMCell                 |   10      1     11     8.0s
      cell: GRUCell                  |   10      1     11     8.6s
    cuda                             |   30      3     33    43.5s
      cell: RNNCell                  |   10      1     11    27.5s
      cell: LSTMCell                 |   10      1     11     7.8s
      cell: GRUCell                  |   10      1     11     8.3s
  Lux                                |   60      6     66  2m05.6s
    test                             |   60      6     66
      test/layers                    |   60      6     66
        test/layers/recurrent_tests.jl |  60    6     66
          Bidirectional              |   60      6     66  2m03.7s
            cpu                      |   30      3     33  1m08.8s
              cell: RNNCell          |   10      1     11    52.2s
              cell: LSTMCell         |   10      1     11     8.0s
              cell: GRUCell          |   10      1     11     8.6s
            cuda                     |   30      3     33    43.5s
              cell: RNNCell          |   10      1     11    27.5s
              cell: LSTMCell         |   10      1     11     7.8s
              cell: GRUCell          |   10      1     11     8.3s
ERROR: LoadError: Some tests did not pass: 120 passed, 0 failed, 12 errored, 0 broken.
in expression starting at /home/nero/Documents/github/Lux.jl/test/runtests.jl:75
```
@avik-pal Hi, I'm sorry to bother you again. I want to keep pushing this PR forward, but it seems there is a limit to what I can do on my own. Is "lazy reverse rrules" a feature that Zygote is missing? Do I need to open an issue for Zygote?
Kind of, but it's not worth opening a Zygote issue for this. If you look at https://buildkite.com/julialang/lux-dot-jl/builds/2994#01906a93-9e6b-4704-a6a1-d3b8e82bb694/350-1748, it is saying that we are doing a broadcast over an `Iterators.Reverse`. The easiest way to resolve this would be to turn the `Iterators.Reverse` into a `Vector`; since that is effectively materializing a vector of pointers, it is not expensive either.
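A hedged sketch of what that suggestion might look like in the recurrence code (the slicing call and the time dimension are assumptions, not the PR's actual implementation):

```julia
# Instead of broadcasting the cell over the lazy reversed iterator
# (Iterators.reverse(eachslice(x; dims = 2))), materialize it into a plain
# Vector first so Zygote's existing array rrules apply. `collect` only
# gathers references to the slices, so this is cheap.
x_slices = collect(Iterators.reverse(eachslice(x; dims = 2)))
```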
@avik-pal Thank you very much for pointing me in the right direction, but I don't really understand "make the Iterators.Reverse into a Vector" here. My current understanding:

```
julia> vec = [1,2,3,4,5]

julia> foreach(println, Iterators.reverse(vec))
5
4
3
2
1
```

Could you give me a few inputs and outputs as examples, and the signature of the function to define? I can implement it from there.
This PR addresses issue #687. Please confirm whether the interface meets the requirements. Thank you.