Adjoint not supported on Diagonal arrays #2275

kiranshila opened this issue Feb 25, 2024 · 1 comment

bug Something isn't working



Describe the bug

Computing the adjoint (A') of a Diagonal backed by a CuArray falls back to scalar indexing.

To reproduce

The Minimal Working Example (MWE) for this bug:

using CUDA, LinearAlgebra

x = CuArray(rand(ComplexF32,5))
A = Diagonal(x)
A' # Errors 

This will work in a REPL, but not running in VSCode (for some reason). The problem becomes more obvious when you use the result in a computation, such as a matmul with a dense array (this might also mean that multiplication by a Diagonal has the same issue, but I'm having difficulty determining that).

y = CuArray(rand(ComplexF32,5,5))
y * A' # Errors

The adjoint works fine in other contexts:

x * x' # Works fine
y * y' # Works fine

Multiplication with a non-adjointed Diagonal works fine as well:

y * A # Works fine
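Since the adjoint of a diagonal matrix is just the diagonal matrix of its conjugated entries, one possible workaround (a sketch, not something proposed in this thread) is to build that matrix directly instead of taking the lazy `A'`. The helper `diag_adjoint` below is hypothetical and restricted to numeric element types. The check runs on the CPU with a plain `Vector`; the assumption is that with a `CuVector` the broadcast `conj.` returns another `CuVector`, giving a `Diagonal` of the same shape as `A`, which the report says multiplies fine.

```julia
using LinearAlgebra

# Hypothetical helper (not part of LinearAlgebra or CUDA.jl): materialize the
# adjoint of a numeric Diagonal as a new Diagonal over conjugated storage.
# For D = Diagonal(d), D' has diagonal conj.(d), so no lazy wrapper is needed.
# Restricted to Number eltypes: block-diagonal matrices would need adjoint.() instead.
diag_adjoint(D::Diagonal{<:Number}) = Diagonal(conj.(D.diag))

# CPU check with plain arrays; on the GPU, d and y would be CuArrays and
# `conj.` would broadcast on the device.
d = rand(ComplexF32, 5)
A = Diagonal(d)
y = rand(ComplexF32, 5, 5)

Ah = diag_adjoint(A)
@show Ah == A'          # same matrix as the lazy adjoint
@show y * Ah ≈ y * A'   # same dense product
```

Unlike `A'`, this allocates a new length-n vector, which is usually negligible next to the cost of a dense product.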

Manifest.toml

# This file is machine-generated - editing it directly is not advised

julia_version = "1.10.1"
manifest_format = "2.0"
project_hash = "0ed3e6dccacf724ab2d9cf21733b6fc74eab6018"

[[deps.ArgTools]]
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
version = "1.1.1"

[[deps.Artifacts]]
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"

[[deps.Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"

[[deps.CompilerSupportLibraries_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
version = "1.1.0+0"

[[deps.Dates]]
deps = ["Printf"]
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"

[[deps.DiffRules]]
deps = ["IrrationalConstants", "LogExpFunctions", "NaNMath", "Random", "SpecialFunctions"]
git-tree-sha1 = "23163d55f885173722d1e4cf0f6110cdbaf7e272"
uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
version = "1.15.1"

[[deps.DocStringExtensions]]
deps = ["LibGit2"]
git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d"
uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
version = "0.9.3"

[[deps.Downloads]]
deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"]
uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
version = "1.6.0"

[[deps.FileWatching]]
uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee"

[[deps.InteractiveUtils]]
deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"

[[deps.IrrationalConstants]]
git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2"
uuid = "92d709cd-6900-40b7-9082-c6be49f344b6"
version = "0.2.2"

[[deps.JLLWrappers]]
deps = ["Artifacts", "Preferences"]
git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca"
uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210"
version = "1.5.0"

[[deps.LibCURL]]
deps = ["LibCURL_jll", "MozillaCACerts_jll"]
uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21"
version = "0.6.4"

[[deps.LibCURL_jll]]
deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"]
uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0"
version = "8.4.0+0"

[[deps.LibGit2]]
deps = ["Base64", "LibGit2_jll", "NetworkOptions", "Printf", "SHA"]
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"

[[deps.LibGit2_jll]]
deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll"]
uuid = "e37daf67-58a4-590a-8e99-b0245dd2ffc5"
version = "1.6.4+0"

[[deps.LibSSH2_jll]]
deps = ["Artifacts", "Libdl", "MbedTLS_jll"]
uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8"
version = "1.11.0+1"

[[deps.Libdl]]
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"

[[deps.LinearAlgebra]]
deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"]
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[[deps.LogExpFunctions]]
deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"]
git-tree-sha1 = "18144f3e9cbe9b15b070288eef858f71b291ce37"
uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
version = "0.3.27"

    [deps.LogExpFunctions.extensions]
    LogExpFunctionsChainRulesCoreExt = "ChainRulesCore"
    LogExpFunctionsChangesOfVariablesExt = "ChangesOfVariables"
    LogExpFunctionsInverseFunctionsExt = "InverseFunctions"

    [deps.LogExpFunctions.weakdeps]
    ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
    ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
    InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"

[[deps.Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"

[[deps.Markdown]]
deps = ["Base64"]
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"

[[deps.MbedTLS_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1"
version = "2.28.2+1"

[[deps.MozillaCACerts_jll]]
uuid = "14a3606d-f60d-562e-9121-12d972cd8159"
version = "2023.1.10"

[[deps.NaNMath]]
deps = ["OpenLibm_jll"]
git-tree-sha1 = "0877504529a3e5c3343c6f8b4c0381e57e4387e4"
uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
version = "1.0.2"

[[deps.NetworkOptions]]
uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
version = "1.2.0"

[[deps.OpenBLAS_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"]
uuid = "4536629a-c528-5b80-bd46-f80d51c5b363"
version = "0.3.23+4"

[[deps.OpenLibm_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "05823500-19ac-5b8b-9628-191a04bc5112"
version = "0.8.1+2"

[[deps.OpenSpecFun_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1"
uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e"
version = "0.5.5+0"

[[deps.Pkg]]
deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
version = "1.10.0"

[[deps.PrecompileTools]]
deps = ["Preferences"]
git-tree-sha1 = "03b4c25b43cb84cee5c90aa9b5ea0a78fd848d2f"
uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
version = "1.2.0"

[[deps.Preferences]]
deps = ["TOML"]
git-tree-sha1 = "00805cd429dcb4870060ff49ef443486c262e38e"
uuid = "21216c6a-2e73-6563-6e65-726566657250"
version = "1.4.1"

[[deps.Printf]]
deps = ["Unicode"]
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"

[[deps.REPL]]
deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"]
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"

[[deps.Random]]
deps = ["SHA"]
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[[deps.Requires]]
deps = ["UUIDs"]
git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7"
uuid = "ae029012-a4dd-5104-9daa-d747884805df"
version = "1.3.0"

[[deps.SHA]]
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
version = "0.7.0"

[[deps.Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"

[[deps.Sockets]]
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"

[[deps.SpecialFunctions]]
deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"]
git-tree-sha1 = "e2cfc4012a19088254b3950b85c3c1d8882d864d"
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
version = "2.3.1"

    [deps.SpecialFunctions.extensions]
    SpecialFunctionsChainRulesCoreExt = "ChainRulesCore"

    [deps.SpecialFunctions.weakdeps]
    ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

[[deps.StaticArrays]]
deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"]
git-tree-sha1 = "bf074c045d3d5ffd956fa0a461da38a44685d6b2"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "1.9.3"

    [deps.StaticArrays.extensions]
    StaticArraysChainRulesCoreExt = "ChainRulesCore"
    StaticArraysStatisticsExt = "Statistics"

    [deps.StaticArrays.weakdeps]
    ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
    Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[[deps.StaticArraysCore]]
git-tree-sha1 = "36b3d696ce6366023a0ea192b4cd442268995a0d"
uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
version = "1.4.2"

[[deps.TOML]]
deps = ["Dates"]
uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
version = "1.0.3"

[[deps.Tar]]
deps = ["ArgTools", "SHA"]
uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e"
version = "1.10.0"

[[deps.Tullio]]
deps = ["DiffRules", "LinearAlgebra", "Requires"]
git-tree-sha1 = "6d476962ba4e435d7f4101a403b1d3d72afe72f3"
uuid = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"
version = "0.3.7"

    [deps.Tullio.extensions]
    TullioCUDAExt = "CUDA"
    TullioChainRulesCoreExt = "ChainRulesCore"
    TullioFillArraysExt = "FillArrays"
    TullioTrackerExt = "Tracker"

    [deps.Tullio.weakdeps]
    CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
    ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
    FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
    Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"

[[deps.UUIDs]]
deps = ["Random", "SHA"]
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"

[[deps.Unicode]]
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"

[[deps.Zlib_jll]]
deps = ["Libdl"]
uuid = "83775a58-1f1d-513f-b197-d71354ab007a"
version = "1.2.13+1"

[[deps.libblastrampoline_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "8e850b90-86db-534c-a0d3-1478176c7d93"
version = "5.8.0+1"

[[deps.nghttp2_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d"
version = "1.52.0+1"

[[deps.p7zip_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0"
version = "17.4.0+2"

Expected behavior

Adjoints of Diagonal arrays backed by a CuArray should behave as they do on the CPU.

Version info

Details on Julia:

Julia Version 1.10.1
Commit 7790d6f0641 (2024-02-13 20:41 UTC)
Build Info:
  Official release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 20 × Intel(R) Core(TM) i9-9900X CPU @ 3.50GHz
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, skylake-avx512)
Threads: 1 default, 0 interactive, 1 GC (on 20 virtual cores)
Environment:
  JULIA_DEPOT_PATH = /home/kiran/.local/share/julia:/home/kiran/.local/share/julia:

Details on CUDA:

CUDA runtime 12.3, artifact installation
CUDA driver 12.3
NVIDIA driver 545.29.6

CUDA libraries: 
- CUBLAS: 12.3.4
- CURAND: 10.3.4
- CUFFT: 11.0.12
- CUSOLVER: 11.5.4
- CUSPARSE: 12.2.0
- CUPTI: 21.0.0
- NVML: 12.0.0+545.29.6

Julia packages: 
- CUDA: 5.2.0
- CUDA_Driver_jll: 0.7.0+1
- CUDA_Runtime_jll: 0.11.1+0

Toolchain:
- Julia: 1.10.1
- LLVM: 15.0.7

1 device:
  0: NVIDIA GeForce RTX 2080 Ti (sm_75, 9.607 GiB / 11.000 GiB available)
kiranshila added the bug (Something isn't working) label Feb 25, 2024

maleadt commented Feb 27, 2024

This is the well-known issue of multiple array wrappers 'breaking' method dispatch, so I think this can be closed in favor of JuliaGPU/Adapt.jl#21.


"This will work in a REPL, but not running in VSCode (for some reason)."

It doesn't actually 'work' in the REPL; we simply allow scalar iteration with a warning in interactive sessions. VSCode apparently isn't recognized as an interactive session, so scalar iteration there raises a hard error instead.

Since it has proven tricky to solve this from the Julia side (see the linked issue, which points to a couple of Base PRs that have stalled), I'm considering switching to unified memory by default so that scalar fallbacks at least perform somewhat decently, although they still wouldn't execute on the GPU.
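The "multiple array wrappers" problem can be observed without a GPU by walking the `parent` chain of the storage behind `A'` on the CPU. `wrapper_depth` is a hypothetical helper written only for this illustration; it relies on Base's `parent`, which returns the array itself once nothing is left to unwrap. The exact wrapper types printed depend on the Julia version, so no output is claimed here. Roughly speaking (an assumption about CUDA.jl internals, not something stated in this thread), GPU-specific methods cover a `CuArray` and a limited set of wrappers around it, so each extra layer makes it more likely that dispatch lands on a generic `AbstractArray` fallback that indexes element by element.

```julia
using LinearAlgebra

# Hypothetical helper: count how many array wrappers (Adjoint, ReshapedArray,
# SubArray, ...) sit on top of the underlying storage. Base defines
# parent(a::AbstractArray) = a, so the loop stops at a plain Array.
function wrapper_depth(x::AbstractArray)
    depth = 0
    while parent(x) !== x
        x = parent(x)
        depth += 1
    end
    return depth
end

d = rand(ComplexF32, 5)
A = Diagonal(d)

println(typeof((A').diag))              # storage type behind the lazy adjoint
println(wrapper_depth((A').diag))       # wrapper layers around that storage
println(wrapper_depth(conj.(A.diag)))   # broadcasting yields a plain Vector
```

With a CuArray in place of `d`, the same inspection shows which wrappers the GPU method tables would have to see through.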

maleadt closed this as completed Feb 27, 2024