Feature request: Bidirectional for RNN layer. #687

Open
NeroBlackstone opened this issue Jun 6, 2024 · 10 comments · May be fixed by #708

@NeroBlackstone
Contributor

Bidirectional(Recurrence(LSTMCell(in_dims => hidden_dims)))
# or
Recurrence(LSTMCell(in_dims => hidden_dims), bidirectional=true)

Which one is better and easier to implement in Lux.jl? I'm willing to try to implement it and open a PR, but I may need some guidance.

@avik-pal
Member

avik-pal commented Jun 6, 2024

I think Bidirectional RNNs can be implemented as:

Parallel(
   <fuse_op>, # User provided
   Recurrence(<rnn cell>; return_sequence=true),
   Chain(
       ReverseSequence(), # Needs to be implemented
       Recurrence(<rnn cell>; return_sequence=true)
   )
)

It would be nice to have a Bidirectional(<fuse_op>, <rnn cell>) constructor which does the above.
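A rough sketch of what such a constructor could look like (illustrative only: ReverseSequence still needs to be implemented, and neither the name nor the signature below is an existing Lux API):

function Bidirectional(fuse_op, cell)
    return Parallel(fuse_op,
        Recurrence(cell; return_sequence=true),
        Chain(ReverseSequence(), Recurrence(cell; return_sequence=true)))
end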

I will have to check what other frameworks do, but I am not sure whether the parameters of the two directions are shared between the layers.

I'm willing to try to implement it and open a pr, but may need some guidance.

For sure!

@NeroBlackstone
Contributor Author

Parallel(
   <fuse_op>, # User provided
   Recurrence(<rnn cell>; return_sequence=true),
   Chain(
       ReverseSequence(), # Needs to be implemented
       Recurrence(<rnn cell>; return_sequence=true)
   )
)

Hi! I would like to ask what role <fuse_op> plays here. If it is for concatenating hidden states, shouldn't it be implemented by Lux rather than provided by the user?

@NeroBlackstone
Contributor Author

NeroBlackstone commented Jun 12, 2024

So I think ReverseSequence() is a helper layer that reverses the specified dimension of the input array?

So it should work like:

x = [1 2; 3 4]

model = ReverseSequence(1)
y, st_new = model(x, ps, st)
# y = [3 4; 1 2]

model = ReverseSequence(2)
y, st_new = model(x, ps, st)
# y = [2 1; 4 3]

Is my understanding correct? Are there any suggestions for high-performance implementation?
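For reference, the core operation could simply delegate to Base's reverse with a dims keyword; the function name below is only illustrative:

# Reverse the chosen dimension of the input array (a plain-Julia sketch).
reverse_sequence(x::AbstractArray, dim::Integer) = reverse(x; dims=dim)

x = [1 2; 3 4]
reverse_sequence(x, 1)  # [3 4; 1 2]
reverse_sequence(x, 2)  # [2 1; 4 3]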

@NeroBlackstone
Contributor Author

NeroBlackstone commented Jun 13, 2024

I think a ReverseSequence should also be applied to the output of the backward RNN layer, so that the output for the first time step of the forward RNN is aligned with the backward RNN's output for that same time step (which the backward RNN produces last):

Parallel(
   <fuse_op>, # User provided
   Recurrence(<rnn cell>; return_sequence=true),
   Chain(
       ReverseSequence(), # Needs to be implemented
       Recurrence(<rnn cell>; return_sequence=true),
       ReverseSequence()
   )
)

And what <fuse_op> should we use? With return_sequence=true each branch outputs a length-seq_len vector of (hidden_dims × batch_size) matrices, so <fuse_op> actually needs to be a broadcast vcat, but Parallel's connection does not support broadcasting.
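A small shape sketch in plain Julia (dimensions made up for illustration) of why a plain vcat connection does not work here:

# Each branch with return_sequence=true yields a vector of per-timestep matrices.
fwd = [rand(Float32, 2, 4) for _ in 1:5]   # 5 timesteps, hidden_dims = 2, batch = 4
bwd = [rand(Float32, 2, 4) for _ in 1:5]
fused = vcat.(fwd, bwd)                    # broadcast vcat over timesteps
size(first(fused))                         # (4, 4): both directions stacked per timestep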

@avik-pal
Member

TensorFlow allows you to choose the fuse op, which is why I want it to be a user choice. For broadcasting, just wrap the user-provided op (default to vcat) with BroadcastFunction(<op>). BroadcastFunction is in Base.

@NeroBlackstone
Contributor Author

TensorFlow allows you to choose the fuse op, which is why I want it to be a user choice. For broadcasting, just wrap the user-provided op (default to vcat) with BroadcastFunction(<op>). BroadcastFunction is in Base.

Keras provides some default merge modes ({"sum", "mul", "concat", "ave", None}); should we provide them too?

Or we could provide a default concat implementation plus a custom <fuse_op> option.

@NeroBlackstone
Contributor Author

TensorFlow allows you to choose the fuse op, which is why I want it to be a user choice. For broadcasting, just wrap the user-provided op (default to vcat) with BroadcastFunction(<op>). BroadcastFunction is in Base.

I can't find Base.BroadcastFunction in the Julia documentation. Could you please give me a detailed code example of how to use a broadcast function in a Parallel layer? I still think this is impossible.

@NeroBlackstone
Contributor Author

could you please give me a detailed code example of how to use a broadcast function in a Parallel layer? I still think this is impossible

OK, I figured it out:

bvcat(a, b) = vcat.(a, b)
model = Parallel(bvcat,
    Recurrence(GRUCell(3 => 2); return_sequence=true),
    Recurrence(GRUCell(3 => 2); return_sequence=true))
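A hedged usage sketch of the model above, assuming Recurrence accepts a vector of per-timestep matrices as input (shapes are illustrative):

using Lux, Random

rng = Random.default_rng()
ps, st = Lux.setup(rng, model)
x = [rand(Float32, 3, 4) for _ in 1:5]  # 5 timesteps, in_dims = 3, batch = 4
y, st_new = model(x, ps, st)            # y: 5-element vector of 4×4 matrices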

@NeroBlackstone
Contributor Author

I have roughly figured out how to build a bidirectional RNN using Parallel layers. I will start implementing it once #698 is merged.

@avik-pal
Member

help?> Broadcast.BroadcastFunction
  BroadcastFunction{F} <: Function


  Represents the "dotted" version of an operator, which broadcasts the operator over its arguments, so BroadcastFunction(op) is functionally equivalent to (x...) -> (op).(x...).

  Can be created by just passing an operator preceded by a dot to a higher-order function.

  Examples
  ≡≡≡≡≡≡≡≡

  julia> a = [[1 3; 2 4], [5 7; 6 8]];
  
  julia> b = [[9 11; 10 12], [13 15; 14 16]];
  
  julia> map(.*, a, b)
  2-element Vector{Matrix{Int64}}:
   [9 33; 20 48]
   [65 105; 84 128]
  
  julia> Base.BroadcastFunction(+)(a, b) == a .+ b
  true


  │ Julia 1.6
  │
  │  BroadcastFunction and the standalone .op syntax are available as of Julia 1.6.

This is basically what you wrote, just do Broadcast.BroadcastFunction(vcat).
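So the Parallel example above could be written without the bvcat helper, e.g.:

model = Parallel(Broadcast.BroadcastFunction(vcat),
    Recurrence(GRUCell(3 => 2); return_sequence=true),
    Recurrence(GRUCell(3 => 2); return_sequence=true))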
