Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chain rules for DCT #273

Open
wants to merge 16 commits into
base: master
Choose a base branch
from
Open

chain rules for DCT #273

wants to merge 16 commits into from

Conversation

vpuri3
Copy link

@vpuri3 vpuri3 commented Jun 26, 2023

address #272

@codecov
Copy link

codecov bot commented Jun 26, 2023

Codecov Report

Attention: 18 lines in your changes are missing coverage. Please review.

Comparison is base (ef8fc5b) 73.08% compared to head (26df888) 70.70%.

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #273      +/-   ##
==========================================
- Coverage   73.08%   70.70%   -2.38%     
==========================================
  Files           5        6       +1     
  Lines         535      553      +18     
==========================================
  Hits          391      391              
- Misses        144      162      +18     
Files Coverage Δ
src/FFTW.jl 85.71% <ø> (ø)
ext/FFTWChainRulesCoreExt.jl 0.00% <0.00%> (ø)

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@vpuri3
Copy link
Author

vpuri3 commented Jun 26, 2023

This error is only happening with the MKL provider. With MKL, FFTW.jl doesn't even compile on my machine. Could be due to 008bc5b?

test_frule: idct on Array{Float64, 3},Int64: Error During Test at /home/runner/.julia/packages/ChainRulesTestUtils/C9L2i/src/testers.jl:123
  Got exception outside of a @test
  FFTW could not create plan
  Stacktrace:
    [1] error(s::String)
      @ Base ./error.jl:35
    [2] macro expansion
      @ FFTW ~/work/FFTW.jl/FFTW.jl/src/fft.jl:722 [inlined]

@vpuri3
Copy link
Author

vpuri3 commented Jul 8, 2023

@devmotion, could you please review this? LMK if you want me to remove the requires part.

Copy link
Member

@devmotion devmotion left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you explain the idea of the PR? The design goal was to define the differentiation rules in AbstractFFTs via adjoint plans. Downstream packages were supposed to implement this new adjoint interface.

@vpuri3
Copy link
Author

vpuri3 commented Jul 8, 2023

Only the MKL tests are failing. Is there a regression with MKL?

@devmotion
Copy link
Member

Can you explain the idea of the PR? The design goal was to define the differentiation rules in AbstractFFTs via adjoint plans. Downstream packages were supposed to implement this new adjoint interface.

More concretely, #249 outlines the intended approach. FFTW was not supposed to define custom rules.

@vpuri3
Copy link
Author

vpuri3 commented Jul 8, 2023

This PR defines Chain Rules for FFTW.dct, idct since these methods are not defined in AbstractFFTs.

@devmotion
Copy link
Member

Ref: JuliaMath/AbstractFFTs.jl#56

@vpuri3
Copy link
Author

vpuri3 commented Jul 8, 2023

With #249 , gradient computation would error for DCT/IDCT. This is because dct(x) is defined as dct(x) = plan_dct(x) * x, and Zygote ends up going inside plan_dct where it encounters some error. This is why we need an rrule for dct.

julia> using FFTW, Zygote                                                                                                                                                                                       
                                                                                                                                                                                                                
julia>                                                                                                                                                                                                          
julia> using LinearAlgebra, FFTW, Zygote                                                                                                                                                                        
                                                                                                                                                                                                                
julia> x = rand(4)                                                                                                                                                                                              
4-element Vector{Float64}:                                                                                                                                                                                      
 0.8692266334693106                                                                                                                                                                                             
 0.6938635624794242                                                                                                                                                                                             
 0.552208368655668                                                                                                                                                                                              
 0.9197557963740512                                                                                                                                                                                             
                                                                                                                                                                                                                
julia> f(x) = x |> dct |> idct |> norm                                                                                                                                                                          
f (generic function with 1 method)                                                                                                                                                                              
                                                                                                                                                                                                                
julia> f(x)                                                                                                                                                                                                     
1.5452787421840921                                                                                                                                                                                              
                                                                                                                                                                                                                
julia> Zygote.gradient(f, x)                                                                                                                                                                                    
ERROR: Compiling Tuple{Type{FFTW.r2rFFTWPlan{Float64, Any, false, 1}}, Vector{Float64}, FFTW.FakeArray{Float64, 1}, UnitRange{Int64}, Int64, UInt32, Float64}: try/catch is not supported.                      
Refer to the Zygote documentation for fixes.        
https://fluxml.ai/Zygote.jl/latest/limitations                                                                                                                                                                  

Copy link
Member

@devmotion devmotion left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, somehow I missed yesterday that the PR does not add rules for plans but for the dct, idct and r2r functions 🤦

As long as they/their interface is not moved to AbstractFFTs, rules should be defined here 👍

Project.toml Outdated Show resolved Hide resolved
ext/FFTWChainRulesCoreExt.jl Outdated Show resolved Hide resolved
ext/FFTWChainRulesCoreExt.jl Outdated Show resolved Hide resolved

# R2R

function ChainRulesCore.frule(Δ, ::typeof(r2r), x::AbstractArray, region...)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems the rrule for r2r is missing?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The R2R transforms are not unitary. There is some scaling involved that depends on the kind of R2R transform. Because it looks like an involved task, I chose to skip that for now. I am happy to look into that in a separate PR

src/FFTW.jl Outdated Show resolved Hide resolved
src/FFTW.jl Outdated Show resolved Hide resolved
test/runtests.jl Outdated Show resolved Hide resolved
@vpuri3
Copy link
Author

vpuri3 commented Jul 10, 2023

@devmotion I've addressed all your comments. LMK if you have more questions :D

@vpuri3
Copy link
Author

vpuri3 commented Jul 15, 2023

@devmotion ping :)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants