Bidirectional RNN #708

Open
NeroBlackstone wants to merge 12 commits into main from Bidirectional

Conversation

NeroBlackstone
Contributor

issue #687

Please confirm whether the interface meets the requirements. Thank you.

codecov bot commented Jun 16, 2024

Codecov Report

Attention: Patch coverage is 0% with 6 lines in your changes missing coverage. Please review.

Project coverage is 76.81%. Comparing base (1a61165) to head (e7ae725).

Files                      Patch %    Lines
src/layers/recurrent.jl    0.00%      6 Missing ⚠️

❗ There is a different number of reports uploaded between BASE (1a61165) and HEAD (e7ae725).

HEAD has 25 fewer uploads than BASE: 38 uploads for BASE (1a61165) vs. 13 for HEAD (e7ae725).
Additional details and impacted files
@@             Coverage Diff             @@
##             main     #708       +/-   ##
===========================================
- Coverage   96.40%   76.81%   -19.60%     
===========================================
  Files          54       54               
  Lines        2726     2730        +4     
===========================================
- Hits         2628     2097      -531     
- Misses         98      633      +535     


NeroBlackstone changed the title from [WIP] Bidirectional RNN to Bidirectional RNN on Jun 17, 2024
if mode != "AMDGPU"
# gradients test failed after vcat
# __f = p -> sum(Base.Fix1(sum, abs2), first(bi_rnn(x, p, st)))
# @eval @test_gradients $__f $ps atol=1e-2 rtol=1e-2 gpu_testing=$ongpu
Contributor Author

The gradient test fails for the matrix produced after vcat. I'm not sure whether the gradient-test code is written incorrectly or whether my Bidirectional implementation is incorrect.

Member

Which backend was failing? The basic pipeline is: compute with Zygote, then compute with FiniteDifferences, ReverseDiff, Tracker, ..., and compare those against Zygote.
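
For reference, a minimal sketch of that comparison pipeline (an editorial illustration, not code from the PR; it assumes bi_rnn, x, ps, and st are set up as in the commented-out test above):

using Zygote, FiniteDifferences

__f = p -> sum(Base.Fix1(sum, abs2), first(bi_rnn(x, p, st)))
gs_zygote = only(Zygote.gradient(__f, ps))
gs_fdm = only(FiniteDifferences.grad(central_fdm(5, 1), __f, ps))
# Compare the two gradient structures field by field with isapprox(...; atol=1e-2, rtol=1e-2);
# this is roughly what the @test_gradients macro automates across backends.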

@NeroBlackstone
Contributor Author

@avik-pal Hi! I have completed the implementation of Bidirectional and written test code, trying to keep it as close to Keras' API as possible. Please review the changes, thank you!
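
For context, a hypothetical usage sketch of the proposed layer (the constructor form, input shape, and output layout are assumptions inferred from the diff hunks quoted later in this thread, not confirmed API):

using Lux, Random

rng = Random.default_rng()
bi_rnn = BidirectionalRNN(RNNCell(3 => 4))   # backward cell defaults to a copy of the forward cell
ps, st = Lux.setup(rng, bi_rnn)
x = rand(rng, Float32, 3, 5, 2)              # assumed (in_dims, timesteps, batch_size) for BatchLastIndex ordering
y, st_ = bi_rnn(x, ps, st)                   # with merge_mode=vcat, each per-timestep output has 4 + 4 features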

@NeroBlackstone
Contributor Author

😭😭😭 Any suggestions? I will fix it today.

src/layers/recurrent.jl (outdated)

## Parameters

- Same as `cell` and `backward_cell`.
Member

It should be a NamedTuple with cell and backward_cell.

src/layers/recurrent.jl (outdated)
backward_cell::Union{AbstractRecurrentCell, Nothing}=nothing,
merge_mode::Union{Function, Nothing}=vcat,
ordering::AbstractTimeSeriesDataBatchOrdering=BatchLastIndex())
if !isnothing(backward_cell) && cell.in_dims != backward_cell.in_dims
Member

Check via backward_cell !== nothing; there used to be a performance consideration for not using isnothing (not sure whether that still holds on current Julia versions, but better to be safe).
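
In other words (keeping the rest of the condition unchanged; the error branch is not shown in the diff, so it is omitted here as well):

if backward_cell !== nothing && cell.in_dims != backward_cell.in_dims
    # ... same error handling as before ...
end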

src/layers/recurrent.jl (outdated)
backward_rnn_layer = isnothing(backward_cell) ? deepcopy(layer) :
Recurrence(backward_cell; return_sequence=true, ordering)
fuse_op = isnothing(merge_mode) ? nothing : Broadcast.BroadcastFunction(merge_mode)
return Parallel(
Member

I was thinking maybe we create a struct

struct BidirectionalRNN <: ....
    model
end

(rnn::...)(...) = rnn.model(....)

This is mostly to be consistent with other wrapper layers in Lux, which return their own type rather than Parallel (as happens here).
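
A slightly more complete version of that sketch (the supertype is an assumption; Lux's container layers at the time subtyped AbstractExplicitContainerLayer parameterized by the wrapped field names):

struct BidirectionalRNN <: AbstractExplicitContainerLayer{(:model,)}
    model::Parallel
end

(rnn::BidirectionalRNN)(x, ps, st::NamedTuple) = rnn.model(x, ps, st)

The constructor can then keep building the Parallel exactly as before and simply return it wrapped in BidirectionalRNN; this matches the model::Parallel hunk quoted further down in the thread.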

avik-pal linked an issue on Jun 19, 2024 that may be closed by this pull request
y_, st__ = bi_rnn_no_merge(x, ps, st)

@jet bi_rnn(x, ps, st)
@jet bi_rnn_no_merge(x, ps, st)
Contributor Author

NeroBlackstone commented Jun 20, 2024

@avik-pal Hi, I have updated the implementation, but the @jet test fails and I can't find the reason. Could you please take a look at it when you have time? 😭😭😭

Member

Yeah, the tests won't really tell you much here. Install JET.jl and run @report_call bi_rnn_no_merge(x, ps, st) target_modules=[Lux, LuxLib, LuxCore] and @report_opt bi_rnn_no_merge(x, ps, st) target_modules=[Lux, LuxLib, LuxCore]. This will give you a list of the function calls that cause the test to fail.

Contributor Author

@avik-pal Hi, I tried @report_call bi_rnn_no_merge(x, ps, st) target_modules=[Lux, LuxLib, LuxCore], but it reported that @report_call expects only one non-keyword argument, so I removed target_modules=[Lux, LuxLib, LuxCore]. I also replaced @jet with @report_call and @report_opt; all tests passed and no errors were thrown.

Is this a LuxTestUtils bug? Could you please check this branch locally? Thank you very much!!

Member

No, don't replace @jet with the report_* functions; you need to run those two in the REPL / VSCode, and they will point to the correct line. The keyword options go before the call expression: @report_call target_modules=[Lux, LuxLib, LuxCore] bi_rnn_no_merge(x, ps, st). See https://aviatesk.github.io/JET.jl/dev/tutorial/#Analyse-methods-with-@report_call

@NeroBlackstone
Contributor Author

@avik-pal 🥹🥹 Hi, could you please help me review the PR and find the reason the @jet test failed? I have no idea what is causing it.

Thank you very much ♥️♥️♥️♥️

model::Parallel
end

(rnn::BidirectionalRNN)(x, ps, st::NamedTuple) = rnn.model(x, ps, st)
Contributor Author

@avik-pal Hi! JET.jl @report_opt bi_rnn(x, ps, st) reports:

(::BidirectionalRNN)(x::Array{…}, ps::@NamedTuple{…}, st::@NamedTuple{…}) @ Lux ./Lux.jl/src/layers/recurrent.jl:690
│ runtime dispatch detected: %1::Parallel(x::Array{Float32, 3}, ps::@NamedTuple{layer_1::@NamedTuple{…}, layer_2::@NamedTuple{…}}, st::@NamedTuple{layer_1::@NamedTuple{…}, layer_2::@NamedTuple{…}})::Any

I have no idea why it says runtime dispatch here... Chain has the same kind of method signature: (c::Chain)(x, ps, st::NamedTuple) = applychain(c.layers, x, ps, st)

Member

It should have more information in the report

Contributor Author

🥲 I'm sorry, the report has only the two lines I posted above.

@NeroBlackstone
Contributor Author

@avik-pal Sorry to bother you. I still don't know how to solve the runtime dispatch error for this Julia code...

(rnn::BidirectionalRNN)(x, ps, st::NamedTuple) = rnn.model(x, ps, st)
julia> @report_opt bi_rnn(x, ps, st)
═════ 1 possible error found ═════
┌ (::BidirectionalRNN)(x::Array{…}, ps::@NamedTuple{…}, st::@NamedTuple{…}) @ Lux ./Lux.jl/src/layers/recurrent.jl:690
│ runtime dispatch detected: %1::Parallel(x::Array{Float32, 3}, ps::@NamedTuple{layer_1::@NamedTuple{…}, layer_2::@NamedTuple{…}}, st::@NamedTuple{layer_1::@NamedTuple{…}, layer_2::@NamedTuple{…}})::Any
└────────────────────
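
One plausible explanation (an editorial note, not something stated in this thread): Parallel is itself a parametric type, so a field declared as model::Parallel is abstractly typed and the call rnn.model(x, ps, st) cannot be resolved statically; Chain avoids this because its layers are stored in a type parameter. Parameterizing the field keeps it concrete, for example (same supertype assumption as in the earlier sketch):

struct BidirectionalRNN{M <: Parallel} <: AbstractExplicitContainerLayer{(:model,)}
    model::M
end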

@avik-pal
Member

I will take a look on the weekend

@NeroBlackstone
Contributor Author

I will take a look on the weekend

😥 Hi, could you help me review this PR?

avik-pal force-pushed the Bidirectional branch 5 times, most recently from b97e29a to e7ae725 on June 30, 2024 19:15
@avik-pal
Member

Fix the gradient tests and it should be fine. The failures probably originate from the lazy-reverse rrules for Zygote not being defined for GPU arrays.

@NeroBlackstone
Contributor Author

Fix the gradient tests and it should be fine. The failures probably originate from the lazy-reverse rrules for Zygote not being defined for GPU arrays.

Thank you for your help!

Some gradient tests still fail here; I have no idea how to define the lazy-reverse rrules... I guess the rest of the implementation may have to be left to you, thanks!

19:22:03 | maxrss 20.0% | mem 67.2% | DONE  (1/1) test item "Bidirectional" 112.9 secs (68.5% compile, 0.1% recompile, 6.2% GC), 188.79 M allocs (13.948 GB)
Test Summary:                          | Pass  Error  Total     Time
ReTestItem Tests                       |  120     12    132  2m08.7s
  Bidirectional                        |   60      6     66  2m03.7s
    cpu                                |   30      3     33  1m08.8s
      cell: RNNCell                    |   10      1     11    52.2s
      cell: LSTMCell                   |   10      1     11     8.0s
      cell: GRUCell                    |   10      1     11     8.6s
    cuda                               |   30      3     33    43.5s
      cell: RNNCell                    |   10      1     11    27.5s
      cell: LSTMCell                   |   10      1     11     7.8s
      cell: GRUCell                    |   10      1     11     8.3s
  Lux                                  |   60      6     66  2m05.6s
    test                               |   60      6     66         
      test/layers                      |   60      6     66         
        test/layers/recurrent_tests.jl |   60      6     66         
          Bidirectional                |   60      6     66  2m03.7s
            cpu                        |   30      3     33  1m08.8s
              cell: RNNCell            |   10      1     11    52.2s
              cell: LSTMCell           |   10      1     11     8.0s
              cell: GRUCell            |   10      1     11     8.6s
            cuda                       |   30      3     33    43.5s
              cell: RNNCell            |   10      1     11    27.5s
              cell: LSTMCell           |   10      1     11     7.8s
              cell: GRUCell            |   10      1     11     8.3s
ERROR: LoadError: Some tests did not pass: 120 passed, 0 failed, 12 errored, 0 broken.
in expression starting at /home/nero/Documents/github/Lux.jl/test/runtests.jl:75

@NeroBlackstone
Contributor Author

@avik-pal Hi, I'm sorry to bother you again. I want to keep pushing this PR forward, but it seems there is a limit to what I can do. Are the "lazy reverse rrules" a feature that Zygote is missing? Do I need to open an issue for Zygote?

@avik-pal
Member

avik-pal commented Jul 4, 2024

Kind of, but not worth opening a Zygote issue for this.

If you look at https://buildkite.com/julialang/lux-dot-jl/builds/2994#01906a93-9e6b-4704-a6a1-d3b8e82bb694/350-1748, it is saying that we are doing a broadcasted vcat of a Vector{<:CuArray} and an Iterators.Reverse{<:CuArray}. Zygote doesn't have a rule for that.

The easiest way to resolve this would be to turn the Iterators.Reverse into a Vector; since that effectively materializes a vector of pointers, it is not expensive either.
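
A minimal runnable sketch of that change (the data here are CPU stand-ins for the per-timestep outputs; in the PR they would be CuArrays produced by the forward and backward recurrences):

# stand-ins for the forward outputs and the lazily reversed backward outputs
ys_forward       = [rand(Float32, 4, 2) for _ in 1:5]
ys_backward_lazy = Iterators.reverse([rand(Float32, 4, 2) for _ in 1:5])

ys_backward = collect(ys_backward_lazy)   # materialize: copies references to the arrays, not the array data
y = vcat.(ys_forward, ys_backward)        # broadcasted vcat now sees two plain Vectors, avoiding the missing rule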

@NeroBlackstone
Contributor Author

make the Iterators.Reverse into a Vector,

@avik-pal Thank you very much for pointing me in the right direction, but I don't really understand "make the Iterators.Reverse into a Vector" here; I think Iterators.Reverse already supports vectors:

julia> vec = [1,2,3,4,5]
julia> foreach(println, Iterators.reverse(vec))
5
4
3
2
1

Could you give me a few example inputs and outputs, and the signature of the function to define? Then I can implement it.

Successfully merging this pull request may close these issues.

Feature request: Bidirectional for RNN layer.