Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bidirectional RNN #708

Merged
merged 16 commits into from
Jul 10, 2024
Merged

Bidirectional RNN #708

merged 16 commits into from
Jul 10, 2024

Conversation

NeroBlackstone
Copy link
Contributor

issue #687

Please confirm whether the interface meets the requirements. Thank you.

Copy link

codecov bot commented Jun 16, 2024

Codecov Report

Attention: Patch coverage is 44.44444% with 10 lines in your changes missing coverage. Please review.

Project coverage is 87.00%. Comparing base (d448c43) to head (2dca0db).
Report is 1 commits behind head on main.

Files Patch % Lines
src/layers/recurrent.jl 0.00% 6 Missing ⚠️
ext/LuxReverseDiffExt/LuxReverseDiffExt.jl 33.33% 2 Missing ⚠️
ext/LuxTrackerExt.jl 33.33% 2 Missing ⚠️

❗ There is a different number of reports uploaded between BASE (d448c43) and HEAD (2dca0db). Click for more details.

HEAD has 5 uploads less than BASE
Flag BASE (d448c43) HEAD (2dca0db)
42 37
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #708      +/-   ##
==========================================
- Coverage   96.30%   87.00%   -9.31%     
==========================================
  Files          57       57              
  Lines        2789     2801      +12     
==========================================
- Hits         2686     2437     -249     
- Misses        103      364     +261     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@NeroBlackstone NeroBlackstone changed the title [WIP] Bidirectional RNN Bidirectional RNN Jun 17, 2024
@NeroBlackstone
Copy link
Contributor Author

@avik-pal Hi! I have completed the implementation of Bidirectional and written test code, trying to be as equivalent to Keras' API. Please review the changes, thank you!

@NeroBlackstone
Copy link
Contributor Author

😭😭😭Any suggestions? I will fix it today

src/layers/recurrent.jl Outdated Show resolved Hide resolved
src/layers/recurrent.jl Outdated Show resolved Hide resolved
src/layers/recurrent.jl Outdated Show resolved Hide resolved
src/layers/recurrent.jl Outdated Show resolved Hide resolved
src/layers/recurrent.jl Outdated Show resolved Hide resolved
src/layers/recurrent.jl Outdated Show resolved Hide resolved
src/layers/recurrent.jl Outdated Show resolved Hide resolved
test/layers/recurrent_tests.jl Outdated Show resolved Hide resolved
@avik-pal avik-pal linked an issue Jun 19, 2024 that may be closed by this pull request
@NeroBlackstone
Copy link
Contributor Author

@avik-pal 🥹🥹Hi, could you please help me review the PR and find the reason why @jet test failed? I haven't any idea about that.

Thank you very much ♥️♥️♥️♥️

@NeroBlackstone
Copy link
Contributor Author

@avik-pal Sorry to bother you. I still don't know how to solve runtime dispatch error for this julia code...

(rnn::BidirectionalRNN)(x, ps, st::NamedTuple) = rnn.model(x, ps, st)
julia> @report_opt bi_rnn(x, ps, st)
═════ 1 possible error found ═════
┌ (::BidirectionalRNN)(x::Array{…}, ps::@NamedTuple{…}, st::@NamedTuple{…}) @ Lux ./Lux.jl/src/layers/recurrent.jl:690
│ runtime dispatch detected: %1::Parallel(x::Array{Float32, 3}, ps::@NamedTuple{layer_1::@NamedTuple{…}, layer_2::@NamedTuple{…}}, st::@NamedTuple{layer_1::@NamedTuple{…}, layer_2::@NamedTuple{…}})::Any
└────────────────────

@avik-pal
Copy link
Member

I will take a look on the weekend

@NeroBlackstone
Copy link
Contributor Author

I will take a look on the weekend

😥Hi, could you help me review this PR?..

@avik-pal avik-pal force-pushed the Bidirectional branch 5 times, most recently from b97e29a to e7ae725 Compare June 30, 2024 19:15
@avik-pal
Copy link
Member

Fix the gradient tests and it should be fine. They are probably originating from lazy reverse rrules for Zygote not being defined for GPU arrays

@NeroBlackstone
Copy link
Contributor Author

Fix the gradient tests and it should be fine. They are probably originating from lazy reverse rrules for Zygote not being defined for GPU arrays

Thank you for your help!

Some gradient tests still failed at here, I have no idea about how to set lazy reverse rrules... I guess the rest of the implementation may have to be left to you, thanks!

19:22:03 | maxrss 20.0% | mem 67.2% | DONE  (1/1) test item "Bidirectional" 112.9 secs (68.5% compile, 0.1% recompile, 6.2% GC), 188.79 M allocs (13.948 GB)
Test Summary:                          | Pass  Error  Total     Time
ReTestItem Tests                       |  120     12    132  2m08.7s
  Bidirectional                        |   60      6     66  2m03.7s
    cpu                                |   30      3     33  1m08.8s
      cell: RNNCell                    |   10      1     11    52.2s
      cell: LSTMCell                   |   10      1     11     8.0s
      cell: GRUCell                    |   10      1     11     8.6s
    cuda                               |   30      3     33    43.5s
      cell: RNNCell                    |   10      1     11    27.5s
      cell: LSTMCell                   |   10      1     11     7.8s
      cell: GRUCell                    |   10      1     11     8.3s
  Lux                                  |   60      6     66  2m05.6s
    test                               |   60      6     66         
      test/layers                      |   60      6     66         
        test/layers/recurrent_tests.jl |   60      6     66         
          Bidirectional                |   60      6     66  2m03.7s
            cpu                        |   30      3     33  1m08.8s
              cell: RNNCell            |   10      1     11    52.2s
              cell: LSTMCell           |   10      1     11     8.0s
              cell: GRUCell            |   10      1     11     8.6s
            cuda                       |   30      3     33    43.5s
              cell: RNNCell            |   10      1     11    27.5s
              cell: LSTMCell           |   10      1     11     7.8s
              cell: GRUCell            |   10      1     11     8.3s
ERROR: LoadError: Some tests did not pass: 120 passed, 0 failed, 12 errored, 0 broken.
in expression starting at /home/nero/Documents/github/Lux.jl/test/runtests.jl:75

@NeroBlackstone
Copy link
Contributor Author

@avik-pal Hi, I'm sorry to bother you again. I want to continue to push this PR, but it seems that there is a limit to what I can do. Is "lazy reverse rrules" a feature that Zygote is missing? Do I need to open an issue for Zygote?

@avik-pal
Copy link
Member

avik-pal commented Jul 4, 2024

Kind of, but not worth opening a Zygote issue for this.

If you look at https://buildkite.com/julialang/lux-dot-jl/builds/2994#01906a93-9e6b-4704-a6a1-d3b8e82bb694/350-1748, it is saying that we doing a broadcast vcat of Vector{<:CuArray} and Iterators.Reverse{<:CuArray}. Zygote doesn't have a rule for that.

The easiest way to resolve this would be to make the Iterators.Reverse into a Vector, since that is effectively materializing a vector of pointers, it is not expensive either.

@NeroBlackstone
Copy link
Contributor Author

make the Iterators.Reverse into a Vector,

@avik-pal Thank you very much for pointing me in the right direction, but I can't really understand "make the Iterators.Reverse into a Vector" here, I think Iterators.Reverse already support vector:

julia> vec = [1,2,3,4,5]
julia> foreach(println, Iterators.reverse(vec))
5
4
3
2
1

Could you give me a few inputs and outputs as examples, and the signature of the function to define? I can implement it if I could

@avik-pal
Copy link
Member

avik-pal commented Jul 7, 2024

I think Iterators.Reverse already support vector:

No that is not what I meant. Try something like

x = [cu(rand(5)) for _ in 1:5]
x_rev = Iterators.reverse(x)

vcat(x, x_rev)

Now try to differentiate the result of vcat wrt x.

@NeroBlackstone
Copy link
Contributor Author

Now try to differentiate the result of vcat wrt x.

@avik-pal So sorry to bother you... Forgive my poor understanding...

I totally don't understand the math meaning here! So I don't know how to differentiate this vcat function, how to express this use Zygote?

Let's say:

x = [1,2,3,4,5]
cat_x = [1,2,3,4,5,1,2,3,4,5]
y = cat_x
# how to differentiate the result???

Thanks all your OSS works, It's great. And I really need this feature.

@avik-pal avik-pal force-pushed the Bidirectional branch 2 times, most recently from b0e692e to 5692dfe Compare July 10, 2024 01:15
@avik-pal avik-pal merged commit d9aa5a6 into LuxDL:main Jul 10, 2024
56 of 64 checks passed
@NeroBlackstone NeroBlackstone deleted the Bidirectional branch July 13, 2024 02:48
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Feature request: Bidirectional for RNN layer.
2 participants