Skip to content

Commit

Permalink
Re-add the tests
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jul 4, 2024
1 parent 541e130 commit 39e3703
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 22 deletions.
6 changes: 4 additions & 2 deletions src/layers/recurrent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -702,6 +702,8 @@ function BidirectionalRNN(cell::AbstractRecurrentCell;
backward_rnn_layer = backward_cell === nothing ? layer :
Recurrence(backward_cell; return_sequence=true, ordering)
fuse_op = merge_mode === nothing ? nothing : Broadcast.BroadcastFunction(merge_mode)
return BidirectionalRNN(Parallel(
fuse_op, layer, Chain(ReverseSequence(), backward_rnn_layer, ReverseSequence())))
return BidirectionalRNN(Parallel(fuse_op;
forward_rnn=layer,
backward_rnn=Chain(;
rev1=ReverseSequence(), rnn=backward_rnn_layer, rev2=ReverseSequence())))
end
32 changes: 12 additions & 20 deletions test/layers/recurrent_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -402,17 +402,14 @@ end
@test size(y_[1]) == (4,)
@test all(x -> size(x) == (5, 2), y_[1])

if mode != "AMDGPU"
# gradients test failed after vcat
# __f = p -> sum(Base.Fix1(sum, abs2), first(bi_rnn(x, p, st)))
# @eval @test_gradients $__f $ps atol=1e-2 rtol=1e-2 gpu_testing=$ongpu
__f = p -> sum(Base.Fix1(sum, abs2), first(bi_rnn(x, p, st)))
@eval @test_gradients $__f $ps atol=1e-2 rtol=1e-2 gpu_testing=$ongpu

__f = p -> sum(
Base.Fix1(sum, abs2), first(first(bi_rnn_no_merge(x, p, st))))
@eval @test_gradients $__f $ps atol=1e-2 rtol=1e-2 gpu_testing=$ongpu
else
# This is just added as a stub to remember about this broken test
__f = p -> begin
(y1, y2), st_ = bi_rnn_no_merge(x, p, st)
return sum(Base.Fix1(sum, abs2), y1) + sum(Base.Fix1(sum, abs2), y2)
end
@eval @test_gradients $__f $ps atol=1e-2 rtol=1e-2 gpu_testing=$ongpu

@testset "backward_cell: $_backward_cell" for _backward_cell in (
RNNCell, LSTMCell, GRUCell)
Expand All @@ -421,7 +418,6 @@ end
bi_rnn = BidirectionalRNN(cell; backward_cell=backward_cell)
bi_rnn_no_merge = BidirectionalRNN(
cell; backward_cell=backward_cell, merge_mode=nothing)
println("BidirectionalRNN:")
display(bi_rnn)

# Batched Time Series
Expand All @@ -441,18 +437,14 @@ end
@test size(y_[1]) == (4,)
@test all(x -> size(x) == (5, 2), y_[1])

if mode != "AMDGPU"
# gradients test failed after vcat
# __f = p -> sum(Base.Fix1(sum, abs2), first(bi_rnn(x, p, st)))
# @eval @test_gradients $__f $ps atol=1e-2 rtol=1e-2 gpu_testing=$ongpu
__f = p -> sum(Base.Fix1(sum, abs2), first(bi_rnn(x, p, st)))
@eval @test_gradients $__f $ps atol=1e-2 rtol=1e-2 gpu_testing=$ongpu

__f = p -> sum(
Base.Fix1(sum, abs2), first(first(bi_rnn_no_merge(x, p, st))))
@eval @test_gradients $__f $ps atol=1e-2 rtol=1e-2 gpu_testing=$ongpu
else
# This is just added as a stub to remember about this broken test
@test_broken 1 + 1 == 1
__f = p -> begin
(y1, y2), st_ = bi_rnn_no_merge(x, p, st)
return sum(Base.Fix1(sum, abs2), y1) + sum(Base.Fix1(sum, abs2), y2)
end
@eval @test_gradients $__f $ps atol=1e-2 rtol=1e-2 gpu_testing=$ongpu
end
end
end
Expand Down

0 comments on commit 39e3703

Please sign in to comment.