Bidirectional RNN #708

Merged (16 commits, Jul 10, 2024)
1 change: 1 addition & 0 deletions docs/src/api/Lux/layers.md
@@ -52,6 +52,7 @@ LSTMCell
RNNCell
Recurrence
StatefulRecurrentCell
BidirectionalRNN
```

## Linear Layers
19 changes: 19 additions & 0 deletions ext/LuxReverseDiffExt/LuxReverseDiffExt.jl
@@ -26,6 +26,25 @@ Lux.apply(m::Lux.AbstractExplicitLayer, x::TrackedArray, ps, st) = m(x, ps, st)
@inline Lux.__eltype(::TrackedReal{T}) where {T} = T
@inline Lux.__eltype(::AbstractArray{<:TrackedReal{T}}) where {T} = T

@inline Lux.__reverse(x::TrackedArray; dims=:) = ArrayInterface.aos_to_soa(reverse(x; dims))
@inline function Lux.__reverse(x::AbstractArray{<:TrackedReal}; dims=:)
return ArrayInterface.aos_to_soa(reverse(x; dims))
end

# multigate: avoid soa formation
@inline function Lux._gate(x::TrackedArray{T, R, 1}, h::Int, n::Int) where {T, R}
return x[Lux._gate(h, n)]
end
@inline function Lux._gate(x::AbstractVector{<:TrackedReal}, h::Int, n::Int)
return ArrayInterface.aos_to_soa(view(x, Lux._gate(h, n)))
end
@inline function Lux._gate(x::TrackedArray{T, R, 2}, h::Int, n::Int) where {T, R}
return x[Lux._gate(h, n), :]
end
@inline function Lux._gate(x::AbstractMatrix{<:TrackedReal}, h::Int, n::Int)
return ArrayInterface.aos_to_soa(view(x, Lux._gate(h, n), :))
end

@inline function Lux.__convert_eltype(::Type{T}, x::AbstractArray{<:TrackedReal}) where {T}
@warn "`Lux.__convert_eltype` doesn't support converting element types of ReverseDiff \
`TrackedReal` arrays. Currently this is a no-op." maxlog=1
5 changes: 5 additions & 0 deletions ext/LuxTrackerExt.jl
@@ -65,6 +65,11 @@ Lux.apply(m::Lux.AbstractExplicitLayer, x::TrackedArray, ps, st) = m(x, ps, st)
@inline Lux.__eltype(::TrackedReal{T}) where {T} = T
@inline Lux.__eltype(::AbstractArray{<:TrackedReal{T}}) where {T} = T

@inline Lux.__reverse(x::TrackedArray; dims=:) = ArrayInterface.aos_to_soa(reverse(x; dims))
@inline function Lux.__reverse(x::AbstractArray{<:TrackedReal}; dims=:)
return ArrayInterface.aos_to_soa(reverse(x; dims))
end

# SimpleChains.jl: DON'T REPLACE THESE WITH @grad_from_chainrules
for T1 in (:TrackedArray, :AbstractArray), T2 in (:TrackedArray, :AbstractArray)
T1 === :AbstractArray && T2 === :AbstractArray && continue
8 changes: 4 additions & 4 deletions src/Lux.jl
@@ -42,6 +42,9 @@ include("preferences.jl")
include("custom_errors.jl")
include("utils.jl")

# Experimental
include("contrib/contrib.jl")

# Layer Implementations
include("layers/basic.jl")
include("layers/containers.jl")
@@ -54,9 +57,6 @@ include("layers/extension.jl")
# Pretty Printing
include("layers/display.jl")

# Experimental
include("contrib/contrib.jl")

# Helpful Functionalities
include("helpers/stateful.jl")
include("helpers/compact.jl")
@@ -93,7 +93,7 @@ export AlphaDropout, Dropout, VariationalHiddenDropout
export BatchNorm, GroupNorm, InstanceNorm, LayerNorm
export WeightNorm
export NoOpLayer, ReshapeLayer, SelectDim, FlattenLayer, WrappedFunction, ReverseSequence
export RNNCell, LSTMCell, GRUCell, Recurrence, StatefulRecurrentCell
export RNNCell, LSTMCell, GRUCell, Recurrence, StatefulRecurrentCell, BidirectionalRNN
export SamePad, TimeLastIndex, BatchLastIndex

export StatefulLuxLayer
2 changes: 1 addition & 1 deletion src/contrib/debug.jl
@@ -45,7 +45,7 @@ See [`Lux.Experimental.@debug_mode`](@ref) to construct this layer.
"""
@concrete struct DebugLayer{NaNCheck, ErrorCheck} <:
AbstractExplicitContainerLayer{(:layer,)}
layer
layer <: AbstractExplicitLayer
location::KeyPath
end

10 changes: 5 additions & 5 deletions src/layers/basic.jl
@@ -86,22 +86,22 @@ end

@inline function (r::ReverseSequence{Nothing})(
x::AbstractVector{T}, ps, st::NamedTuple) where {T}
return (isbitstype(T) ? reverse(x) : Iterators.reverse(x)), st
return __reverse(x), st
end

@inline function (r::ReverseSequence{Nothing})(
x::AbstractArray{T, N}, ps, st::NamedTuple) where {T, N}
return reverse(x; dims=ndims(x) - 1), st
return __reverse(x; dims=ndims(x) - 1), st
end

@inline function (r::ReverseSequence)(x::AbstractVector{T}, ps, st::NamedTuple) where {T}
r.dim == 1 && return reverse(x), st
throw(DimensionMismatch(lazy"Cannot specify a dimension other than 1 for AbstractVector{T}"))
r.dim == 1 && return __reverse(x), st
throw(ArgumentError("Cannot specify a dimension other than 1 for AbstractVector{T}"))
end

@inline function (r::ReverseSequence)(
x::AbstractArray{T, N}, ps, st::NamedTuple) where {T, N}
return reverse(x; dims=r.dim), st
return __reverse(x; dims=r.dim), st
end

"""
6 changes: 6 additions & 0 deletions src/layers/display.jl
@@ -53,6 +53,12 @@ function _printable_children(l::Union{PairwiseFusion, Parallel})
return merge((; l.connection), children.layers)
end
_printable_children(l::SkipConnection) = (; l.connection, l.layers)
function _printable_children(l::BidirectionalRNN)
merge_mode = l.model.connection isa Broadcast.BroadcastFunction ? l.model.connection.f :
nothing
return (; merge_mode, forward_cell=l.model.layers.forward_rnn.cell,
backward_cell=l.model.layers.backward_rnn.rnn.cell)
end

_show_leaflike(x) = Functors.isleaf(x) # mostly follow Functors, except for:
_show_leaflike(x::AbstractExplicitLayer) = false
91 changes: 82 additions & 9 deletions src/layers/recurrent.jl
@@ -1,5 +1,13 @@
abstract type AbstractRecurrentCell{use_bias, train_state} <: AbstractExplicitLayer end

const AbstractDebugRecurrentCell = Experimental.DebugLayer{
<:Any, <:Any, <:AbstractRecurrentCell}

function ConstructionBase.constructorof(::Type{<:AbstractRecurrentCell{
use_bias, train_state}}) where {use_bias, train_state}
return AbstractRecurrentCell{use_bias, train_state}
end

# Fallback for vector inputs
function (rnn::AbstractRecurrentCell)(x::AbstractVector, ps, st::NamedTuple)
(y, carry), st_ = rnn(reshape(x, :, 1), ps, st)
@@ -82,16 +90,16 @@ automatically operate over a sequence of inputs.

For some discussion on this topic, see https://github.com/LuxDL/Lux.jl/issues/472.
"""
struct Recurrence{
R, C <: AbstractRecurrentCell, O <: AbstractTimeSeriesDataBatchOrdering} <:
AbstractExplicitContainerLayer{(:cell,)}
cell::C
ordering::O
@concrete struct Recurrence{R} <: AbstractExplicitContainerLayer{(:cell,)}
cell <: Union{<:AbstractRecurrentCell, <:AbstractDebugRecurrentCell}
ordering <: AbstractTimeSeriesDataBatchOrdering
end

ConstructionBase.constructorof(::Type{<:Recurrence{R}}) where {R} = Recurrence{R}

function Recurrence(cell; ordering::AbstractTimeSeriesDataBatchOrdering=BatchLastIndex(),
return_sequence::Bool=false)
return Recurrence{return_sequence, typeof(cell), typeof(ordering)}(cell, ordering)
return Recurrence{return_sequence}(cell, ordering)
end

_eachslice(x::AbstractArray, ::TimeLastIndex) = _eachslice(x, Val(ndims(x)))
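A short note on why the new `ConstructionBase.constructorof` methods are needed: `return_sequence` (the `R` parameter) is stored only in the type, not in a field, so the default `constructorof` cannot rebuild the layer from its field values alone. The sketch below is illustrative (`MyRecurrence` is a stand-in, not the real type) and assumes the usual ConstructionBase semantics.

```julia
using ConstructionBase

# Stand-in for a layer whose behaviour flag `R` lives only in the type parameters.
struct MyRecurrence{R, C, O}
    cell::C
    ordering::O
end
# Partial constructor so `MyRecurrence{R}(cell, ordering)` infers the field types,
# roughly mirroring what `@concrete` generates for `Recurrence{R}`.
MyRecurrence{R}(cell::C, ordering::O) where {R, C, O} = MyRecurrence{R, C, O}(cell, ordering)

# Keep `R` when the struct is rebuilt (e.g. by Functors/Setfield) from new field values.
ConstructionBase.constructorof(::Type{<:MyRecurrence{R}}) where {R} = MyRecurrence{R}

layer = MyRecurrence{true}("dummy cell", :ordering)
rebuilt = constructorof(typeof(layer))(:new_cell, :ordering)  # MyRecurrence{true, Symbol, Symbol}
```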
@@ -164,9 +172,8 @@ update the state with `Lux.update_state(st, :carry, nothing)`.
+ `cell`: Same as `cell`.
+ `carry`: The carry state of the `cell`.
"""
struct StatefulRecurrentCell{C <: AbstractRecurrentCell} <:
AbstractExplicitContainerLayer{(:cell,)}
cell::C
@concrete struct StatefulRecurrentCell <: AbstractExplicitContainerLayer{(:cell,)}
cell <: Union{<:AbstractRecurrentCell, <:AbstractDebugRecurrentCell}
end

function initialstates(rng::AbstractRNG, r::StatefulRecurrentCell)
@@ -641,3 +648,69 @@ function Base.show(io::IO, g::GRUCell{use_bias, TS}) where {use_bias, TS}
TS && print(io, ", train_state=true")
return print(io, ")")
end

"""
BidirectionalRNN(cell::AbstractRecurrentCell,
backward_cell::Union{AbstractRecurrentCell, Nothing}=nothing;
merge_mode::Union{Function, Nothing}=vcat,
ordering::AbstractTimeSeriesDataBatchOrdering=BatchLastIndex())

Bidirectional RNN wrapper.

## Arguments

- `cell`: A recurrent cell. See [`RNNCell`](@ref), [`LSTMCell`](@ref), and [`GRUCell`](@ref)
    for how the inputs/outputs of a recurrent cell must be structured.
- `backward_cell`: An optional backward recurrent cell. If `backward_cell` is `nothing`,
    the cell passed as the `cell` argument is also used to generate the backward layer
    automatically. The `in_dims` of `backward_cell` must match the `in_dims` of `cell`.

## Keyword Arguments

- `merge_mode`: Function used to combine the outputs of the forward and backward RNNs.
    Defaults to `vcat`. If `nothing`, the outputs are not combined.
- `ordering`: The ordering of the batch and time dimensions in the input. Defaults to
    `BatchLastIndex()`. Alternatively, it can be set to `TimeLastIndex()`.

## Inputs

- If `x` is a

+ Tuple or Vector: Each element is fed to the `cell` sequentially.

+ Array (except a Vector): It is sliced along the penultimate dimension and each
    slice is fed to the `cell` sequentially.

## Returns

- Merged output of the `cell` and `backward_cell` for the entire sequence.
- Updated state of the `cell` and `backward_cell`.

## Parameters

- `NamedTuple` with `cell` and `backward_cell`.

## States

- Same as `cell` and `backward_cell`.
"""
@concrete struct BidirectionalRNN <: AbstractExplicitContainerLayer{(:model,)}
model <: Parallel
end

(rnn::BidirectionalRNN)(x, ps, st::NamedTuple) = rnn.model(x, ps, st)

function BidirectionalRNN(cell::AbstractRecurrentCell,
backward_cell::Union{AbstractRecurrentCell, Nothing}=nothing;
merge_mode::Union{Function, Nothing}=vcat,
ordering::AbstractTimeSeriesDataBatchOrdering=BatchLastIndex())
layer = Recurrence(cell; return_sequence=true, ordering)
backward_rnn_layer = backward_cell === nothing ? layer :
Recurrence(backward_cell; return_sequence=true, ordering)
fuse_op = merge_mode === nothing ? nothing : Broadcast.BroadcastFunction(merge_mode)
return BidirectionalRNN(Parallel(fuse_op;
forward_rnn=layer,
backward_rnn=Chain(;
rev1=ReverseSequence(), rnn=backward_rnn_layer, rev2=ReverseSequence())))
end
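For readers skimming the diff, a minimal usage sketch of the new layer (the shapes match the tests below; everything else is the standard Lux workflow and not part of the PR itself):

```julia
using Lux, Random

rng = Random.default_rng()

cell = LSTMCell(3 => 5)
bi_rnn = BidirectionalRNN(cell)                              # merge_mode defaults to vcat
bi_rnn_no_merge = BidirectionalRNN(cell; merge_mode=nothing)

ps, st = Lux.setup(rng, bi_rnn)

# Batched time series: (features, time, batch) with the default BatchLastIndex() ordering.
x = randn(rng, Float32, 3, 4, 2)

y, st_ = bi_rnn(x, ps, st)                # 4-element sequence; each element is 10×2 (vcat of 5 + 5)
(yf, yb), _ = bi_rnn_no_merge(x, ps, st)  # forward and backward sequences, each 4 elements of 5×2
```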
2 changes: 2 additions & 0 deletions src/utils.jl
@@ -375,3 +375,5 @@ end
@inline __set_refval!(x, y) = (x[] = y)

@inline __eltype(x) = eltype(x)

@inline __reverse(x; dims=:) = reverse(x; dims)
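For context, a minimal sketch of the indirection this fallback provides: the layers call `Lux.__reverse` instead of `Base.reverse` directly, so the ReverseDiff and Tracker extensions above only need to overload a single function for their tracked array types. The names below are illustrative, not the actual Lux internals.

```julia
module MiniLux
    # Generic fallback, as in src/utils.jl.
    __reverse(x; dims=:) = reverse(x; dims)

    # What ReverseSequence effectively does for N-dimensional inputs.
    reverse_sequence(x::AbstractArray) = __reverse(x; dims=ndims(x) - 1)
end

# A weak-dependency extension can then reroute tracked arrays through `aos_to_soa`
# without MiniLux ever depending on the AD package, e.g.:
#     MiniLux.__reverse(x::TrackedArray; dims=:) = ArrayInterface.aos_to_soa(reverse(x; dims))

x = randn(Float32, 3, 4, 2)
MiniLux.reverse_sequence(x)  # reverses the time axis (dim 2) via the generic fallback
```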
2 changes: 1 addition & 1 deletion test/layers/basic_tests.jl
@@ -35,7 +35,7 @@

@test layer(x, ps, st)[1] == aType(xr)
@test layer(x2, ps, st)[1] == aType(x2rd1)
@test_throws DimensionMismatch layer2(x, ps2, st2)[1]
@test_throws ArgumentError layer2(x, ps2, st2)[1]
@test layer3(x, ps3, st3)[1] == aType(xr)
@test layer2(x2, ps2, st2)[1] == aType(x2rd2)

74 changes: 74 additions & 0 deletions test/layers/recurrent_tests.jl
@@ -374,3 +374,77 @@ end
@test_throws ErrorException Lux._eachslice(x, BatchLastIndex())
end
end

@testitem "Bidirectional" timeout=3000 setup=[SharedTestSetup] tags=[:recurrent_layers] begin
rng = StableRNG(12345)

@testset "$mode" for (mode, aType, device, ongpu) in MODES
@testset "cell: $_cell" for _cell in (RNNCell, LSTMCell, GRUCell)
cell = _cell(3 => 5)
bi_rnn = BidirectionalRNN(cell)
bi_rnn_no_merge = BidirectionalRNN(cell; merge_mode=nothing)
display(bi_rnn)

# Batched Time Series
x = randn(rng, Float32, 3, 4, 2) |> aType
ps, st = Lux.setup(rng, bi_rnn) .|> device
y, st_ = bi_rnn(x, ps, st)
y_, st__ = bi_rnn_no_merge(x, ps, st)

@jet bi_rnn(x, ps, st)
@jet bi_rnn_no_merge(x, ps, st)

@test size(y) == (4,)
@test all(x -> size(x) == (10, 2), y)

@test length(y_) == 2
@test size(y_[1]) == size(y_[2])
@test size(y_[1]) == (4,)
@test all(x -> size(x) == (5, 2), y_[1])

__f = p -> sum(Base.Fix1(sum, abs2), first(bi_rnn(x, p, st)))
@eval @test_gradients $__f $ps atol=1e-2 rtol=1e-2 gpu_testing=$ongpu

__f = p -> begin
(y1, y2), st_ = bi_rnn_no_merge(x, p, st)
return sum(Base.Fix1(sum, abs2), y1) + sum(Base.Fix1(sum, abs2), y2)
end
@eval @test_gradients $__f $ps atol=1e-2 rtol=1e-2 gpu_testing=$ongpu

@testset "backward_cell: $_backward_cell" for _backward_cell in (
RNNCell, LSTMCell, GRUCell)
cell = _cell(3 => 5)
backward_cell = _backward_cell(3 => 5)
bi_rnn = BidirectionalRNN(cell, backward_cell)
bi_rnn_no_merge = BidirectionalRNN(cell, backward_cell; merge_mode=nothing)
display(bi_rnn)

# Batched Time Series
x = randn(rng, Float32, 3, 4, 2) |> aType
ps, st = Lux.setup(rng, bi_rnn) .|> device
y, st_ = bi_rnn(x, ps, st)
y_, st__ = bi_rnn_no_merge(x, ps, st)

@jet bi_rnn(x, ps, st)
@jet bi_rnn_no_merge(x, ps, st)

@test size(y) == (4,)
@test all(x -> size(x) == (10, 2), y)

@test length(y_) == 2
@test size(y_[1]) == size(y_[2])
@test size(y_[1]) == (4,)
@test all(x -> size(x) == (5, 2), y_[1])

__f = p -> sum(Base.Fix1(sum, abs2), first(bi_rnn(x, p, st)))
@eval @test_gradients $__f $ps atol=1e-2 rtol=1e-2 gpu_testing=$ongpu

__f = p -> begin
(y1, y2), st_ = bi_rnn_no_merge(x, p, st)
return sum(Base.Fix1(sum, abs2), y1) + sum(Base.Fix1(sum, abs2), y2)
end
@eval @test_gradients $__f $ps atol=1e-2 rtol=1e-2 gpu_testing=$ongpu
end
end
end
end