Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bounce averaging #854

Open
wants to merge 276 commits into
base: master
Choose a base branch
from
Open

Bounce averaging #854

wants to merge 276 commits into from

Conversation

unalmis
Copy link
Collaborator

@unalmis unalmis commented Feb 4, 2024

This PR is completed. An alternative PR with more intelligent math (see #1045 ) will likely replace it. We should keep/merge both for comparison though.

  • differentiable algorithm to compute bounce points.
  • Fixed some bugs with numpy compatibility.
  • differentiable algorithm to compute bounce integrals.
  • works with any numerical quadrature
  • Matched analytic theory.

After this pull request:

@unalmis unalmis linked an issue Feb 4, 2024 that may be closed by this pull request
Copy link
Contributor

github-actions bot commented Feb 4, 2024

|             benchmark_name             |         dt(%)          |         dt(s)          |        t_new(s)        |        t_old(s)        | 
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
 test_build_transform_fft_lowres         |     +9.64 +/- 6.35     | +4.95e-02 +/- 3.26e-02 |  5.63e-01 +/- 2.7e-02  |  5.14e-01 +/- 1.8e-02  |
 test_build_transform_fft_midres         |     +3.15 +/- 7.76     | +1.89e-02 +/- 4.65e-02 |  6.19e-01 +/- 4.3e-02  |  6.00e-01 +/- 1.7e-02  |
 test_build_transform_fft_highres        |     -1.45 +/- 2.46     | -1.47e-02 +/- 2.49e-02 |  9.98e-01 +/- 8.0e-03  |  1.01e+00 +/- 2.4e-02  |
 test_equilibrium_init_lowres            |     -0.35 +/- 4.77     | -1.30e-02 +/- 1.80e-01 |  3.77e+00 +/- 1.2e-01  |  3.78e+00 +/- 1.4e-01  |
 test_equilibrium_init_medres            |     +1.73 +/- 5.48     | +7.35e-02 +/- 2.33e-01 |  4.33e+00 +/- 1.8e-01  |  4.26e+00 +/- 1.5e-01  |
 test_equilibrium_init_highres           |     -3.70 +/- 4.69     | -2.24e-01 +/- 2.85e-01 |  5.84e+00 +/- 2.1e-01  |  6.07e+00 +/- 1.9e-01  |
 test_objective_compile_dshape_current   |     -5.37 +/- 4.94     | -2.17e-01 +/- 2.00e-01 |  3.82e+00 +/- 2.7e-02  |  4.04e+00 +/- 2.0e-01  |
 test_objective_compile_atf              |     -1.29 +/- 2.46     | -1.09e-01 +/- 2.08e-01 |  8.34e+00 +/- 1.1e-01  |  8.45e+00 +/- 1.8e-01  |
 test_objective_compute_dshape_current   |     +2.68 +/- 3.78     | +3.36e-05 +/- 4.73e-05 |  1.29e-03 +/- 2.5e-05  |  1.25e-03 +/- 4.0e-05  |
 test_objective_compute_atf              |     +1.68 +/- 3.86     | +7.19e-05 +/- 1.65e-04 |  4.34e-03 +/- 1.2e-04  |  4.27e-03 +/- 1.1e-04  |
 test_objective_jac_dshape_current       |     -2.08 +/- 5.23     | -8.05e-04 +/- 2.03e-03 |  3.79e-02 +/- 1.5e-03  |  3.87e-02 +/- 1.4e-03  |
 test_objective_jac_atf                  |     +0.24 +/- 3.52     | +4.57e-03 +/- 6.65e-02 |  1.89e+00 +/- 3.3e-02  |  1.89e+00 +/- 5.8e-02  |
 test_perturb_1                          |     -0.85 +/- 1.71     | -1.16e-01 +/- 2.32e-01 |  1.34e+01 +/- 1.2e-01  |  1.35e+01 +/- 2.0e-01  |
 test_perturb_2                          |     +0.93 +/- 2.35     | +1.78e-01 +/- 4.49e-01 |  1.93e+01 +/- 2.9e-01  |  1.91e+01 +/- 3.4e-01  |
 test_proximal_jac_atf                   |     -0.57 +/- 1.36     | -4.63e-02 +/- 1.10e-01 |  8.08e+00 +/- 4.8e-02  |  8.12e+00 +/- 9.9e-02  |
 test_proximal_freeb_compute             |     +1.96 +/- 0.80     | +3.50e-03 +/- 1.42e-03 |  1.82e-01 +/- 1.1e-03  |  1.79e-01 +/- 8.9e-04  |
 test_proximal_freeb_jac                 |     +2.70 +/- 1.67     | +1.98e-01 +/- 1.22e-01 |  7.53e+00 +/- 1.1e-01  |  7.33e+00 +/- 5.0e-02  |
 test_solve_fixed_iter                   |     -2.03 +/- 12.76    | -3.78e-01 +/- 2.38e+00 |  1.83e+01 +/- 1.5e+00  |  1.86e+01 +/- 1.9e+00  |

Copy link

codecov bot commented Feb 4, 2024

Codecov Report

Attention: Patch coverage is 91.08108% with 33 lines in your changes missing coverage. Please review.

Project coverage is 95.40%. Comparing base (13108f6) to head (714a8f0).
Report is 23 commits behind head on master.

Files Patch % Lines
desc/compute/bounce_integral.py 91.57% 30 Missing ⚠️
desc/backend.py 66.66% 3 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master     #854      +/-   ##
==========================================
- Coverage   95.43%   95.40%   -0.03%     
==========================================
  Files          87       88       +1     
  Lines       22313    22679     +366     
==========================================
+ Hits        21294    21637     +343     
- Misses       1019     1042      +23     
Files Coverage Δ
desc/grid.py 92.86% <100.00%> (+0.18%) ⬆️
desc/backend.py 88.88% <66.66%> (-1.36%) ⬇️
desc/compute/bounce_integral.py 91.57% <91.57%> (ø)

... and 2 files with indirect coverage changes

@f0uriest
Copy link
Member

f0uriest commented Feb 5, 2024

See also #719 for some other related ideas.

Copy link
Collaborator

@rahulgaur104 rahulgaur104 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a few suggestions. I'll add the analytical test soon. It's somewhere in my old laptop.

desc/compute/utils.py Outdated Show resolved Hide resolved
desc/compute/utils.py Outdated Show resolved Hide resolved
desc/compute/utils.py Outdated Show resolved Hide resolved
desc/compute/utils.py Outdated Show resolved Hide resolved
@dpanici
Copy link
Collaborator

dpanici commented Feb 14, 2024

  • Use splines for both rootfind and integration
  • Keep exact method you have here for comparison

@unalmis unalmis marked this pull request as ready for review February 18, 2024 12:03
desc/compute/utils.py Outdated Show resolved Hide resolved
desc/compute/utils.py Outdated Show resolved Hide resolved
tests/test_compute_utils.py Outdated Show resolved Hide resolved
@unalmis unalmis marked this pull request as draft February 18, 2024 21:43
Copy link
Collaborator

@dpanici dpanici left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Spline num/den separately, evaluate integrals with quadratures

desc/compute/utils.py Outdated Show resolved Hide resolved
@unalmis unalmis marked this pull request as ready for review February 25, 2024 09:03
@unalmis
Copy link
Collaborator Author

unalmis commented Feb 25, 2024

I would like to use a Hermite spline for |B|, but it looks like interpax doesn't support that. Should I add that?

desc/compute/utils.py Outdated Show resolved Hide resolved
desc/backend.py Outdated Show resolved Hide resolved
Copy link
Collaborator

@rahulgaur104 rahulgaur104 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have an idea for a simple tokamak test that we can use to test this PR. I'll add the details of that test and more comments soon.

desc/backend.py Outdated Show resolved Hide resolved
desc/backend.py Show resolved Hide resolved
desc/compute/_field.py Outdated Show resolved Hide resolved
tests/test_compute_utils.py Outdated Show resolved Hide resolved
@unalmis unalmis changed the base branch from master to clebsh_basis July 20, 2024 06:15
@unalmis
Copy link
Collaborator Author

unalmis commented Jul 22, 2024

could add one-liner jnp.where statement to zero out quadrature nodes at middle peak of W shaped wells. this option would resolve numerical instability for strong singularities, which could potentially cause a spike in profile

Base automatically changed from clebsh_basis to master August 8, 2024 17:46

import numpy as np
from interpax import CubicHermiteSpline, PPoly, interp1d
from jax.nn import softmax
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should import this in backend

from interpax import CubicHermiteSpline, PPoly, interp1d
from jax.nn import softmax
from matplotlib import pyplot as plt
from orthax.legendre import leggauss
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pretty sure we could use scipy for this to avoid an extra dependency (we use scipy gauss legendre for the bootstrap stuff). Also orthax is still very WIP.

# Compute from analytic formula to avoid the issue of complex roots with small
# imaginary parts and to avoid nan in gradient.
r = func[c.shape[0]](*c[:-1], c[-1] - k, sentinel, eps, distinct)
distinct = distinct and c.shape[0] > 3
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can't you have 2 duplicate real roots for a quadratic?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, but i removed the duplicate in _root_quadratic

desc/compute/bounce_integral.py Show resolved Hide resolved
First axis enumerates the coefficients of power series. Second axis
enumerates the splines along the field lines. Last axis enumerates the
polynomials that compose the spline along a particular field line.
B_z_ra_c : jnp.ndarray
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isnt this just polyder(B_c)?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, but I don't wan't to recompute that for every integral of a particular pitch when i loop over bounce_integrate

return result


def bounce_integral(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this the only "entry point" to this new code? I assume all the _* functions are only needed locally, but are the other non _* functions meant to be semi-public (ie used by developers?) Or is this the only one?

return _filter_distinct(r, sentinel, eps) if distinct else r


def _poly_der(c):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isn't this the same as jnp.polyder?

Copy link
Collaborator Author

@unalmis unalmis Aug 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could replace with

def jax_poly_der(c):
    return jnp.vectorize(jnp.polyder, signature="(m)->(n)")(c.T).T

Under jit, they are the same speed and memory. However, this orders of magnitude slower on numpy backend.

Copy link
Collaborator Author

@unalmis unalmis Aug 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually testing again today looks like the bumpy backend does fine too. I only tested cpu though

def _poly_val(x, c):
"""Evaluate the set of polynomials ``c`` at the points ``x``.

Note this function is not the same as ``np.polynomial.polynomial.polyval(x,c)``.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

np.polyval assumes c is in decreasing order of powers, while np.polynomial.polynomial.polyval assumes c is increasing powers of x. In theory would jnp.polyval work here? (apart from it being possibly slower as it uses horners method which isn't great on gpu)

Copy link
Collaborator Author

@unalmis unalmis Aug 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes

Each of these are equivalent. Speed and memory under jit on cpu shows all are similar. Method 1 is a little faster on cpu, and likely faster still on gpu. Method 1 is my preference

# method 1
val = jnp.linalg.vecdot(
    polyvander(x, c.shape[0] - 1), jnp.moveaxis(jnp.flipud(c), 0, -1)
)
# method 2
val = jnp.vectorize(jnp.polyval, signature="(m),(n)->(n)")(
    jnp.moveaxis(c, 0, -1), x[..., np.newaxis]
).squeeze(axis=-1)
# method 3
val = jnp.einsum(
    "...i,i...", x[..., np.newaxis] ** jnp.arange(c.shape[0] - 1, -1, -1), c
)

tests/test_bounce_integral.py Show resolved Hide resolved
tests/test_bounce_integral.py Show resolved Hide resolved
@f0uriest
Copy link
Member

Main question I forgot to include in the actual review comment:

There are a few different automorphisms (sin, arcsin) and different quadrature methods (gauss-legendre, gauss-chebyshev, tanh-sinh). Is there a clear winning combination? Or is it still problem dependent?

@PlasmaControl PlasmaControl deleted a comment from rahulgaur104 Aug 17, 2024
@unalmis unalmis added the hold merging master recent changes in master demand changes to this branch label Aug 22, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
hold merging master recent changes in master demand changes to this branch
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Adding bounce-averaging functionality
4 participants