Skip to content

Commit

Permalink
Refactor batched_vec (#464)
Browse files Browse the repository at this point in the history
Branching on the type of the second argument caused a subtle performance bug when differentiating via `Zygote`; see #462
  • Loading branch information
jondeuce committed Jan 24, 2023
1 parent 16b7486 commit 7f6ea50
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions src/batched/batchedmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -172,13 +172,12 @@ julia> batched_vec(A,b) |> size
(16, 32)
```
"""
function batched_vec(A::AbstractArray{T,3} where T, B::AbstractMatrix)
    # Generic matrix case. There is deliberately no `isa` check on `B` here:
    # transposed/adjoint strided matrices are routed by dispatch to the
    # dedicated `AdjOrTransAbsMat` method of `batched_vec`, which keeps this
    # body branch-free.
    #
    # Give `B` a singleton middle dimension, size(B,1) × 1 × size(B,2), so it
    # can be passed to `batched_mul` as a 3-dimensional array.
    cols = reshape(B, size(B, 1), 1, size(B, 2))
    # Collapse the product back to a size(A,1) × size(A,3) matrix.
    return reshape(batched_mul(A, cols), size(A, 1), size(A, 3))
end

# A transposed/adjoint `B` has its stride-1 dimension along the batch dim, so a
# copy would be made further down anyway; materialize it up front and dispatch
# again on the plain copy.
function batched_vec(
    A::AbstractArray{T,3} where T,
    B::AdjOrTransAbsMat{<:BlasFloat, <:StridedMatrix},
)
    return batched_vec(A, copy(B))
end

function batched_vec(A::AbstractArray{T,3} where T, b::AbstractVector)
    # Present the vector as a length(b) × 1 × 1 array for `batched_mul`.
    col = reshape(b, length(b), 1, 1)
    # Reshape the product to a size(A,1) × size(A,3) matrix.
    return reshape(batched_mul(A, col), size(A, 1), size(A, 3))
end
Expand Down

0 comments on commit 7f6ea50

Please sign in to comment.