Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert back to standard FFT convention #242

Merged
merged 1 commit into from
Jul 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions docs/examples/read-dataset.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -354,13 +354,14 @@
"\n",
"\n",
"fig, ax = plt.subplots(figsize=(4, 4))\n",
"n_pixels = relion_particle.instrument_config.n_pixels\n",
"spectrum, frequencies = compute_radially_averaged_powerspectrum(\n",
" fourier_image,\n",
" radial_frequency_grid_in_angstroms,\n",
" pixel_size,\n",
" maximum_frequency=1 / (2 * pixel_size),\n",
")\n",
"ax.plot(frequencies, spectrum, color=\"k\")\n",
"ax.plot(frequencies, spectrum / n_pixels, color=\"k\")\n",
"ax.set(\n",
" xlabel=\"frequency magnitude $[\\AA^{-1}]$\",\n",
" ylabel=\"radially averaged power spectrum\",\n",
Expand Down Expand Up @@ -392,7 +393,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
"version": "3.1.-1"
}
},
"nbformat": 4,
Expand Down
6 changes: 1 addition & 5 deletions src/cryojax/image/_downsample.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Routines for downsampling arrays"""

import math
from typing import overload

import jax.numpy as jnp
Expand Down Expand Up @@ -123,11 +122,8 @@ def downsample_to_shape_with_fourier_cropping(
the downsampled array in fourier space assuming hermitian symmetry,
with the zero frequency component in the corner.
"""
n_pixels, new_n_pixels = image_or_volume.size, math.prod(downsampled_shape)
fourier_array = jnp.fft.fftshift(fftn(image_or_volume))
cropped_fourier_array = jnp.sqrt(n_pixels / new_n_pixels) * crop_to_shape(
fourier_array, downsampled_shape
)
cropped_fourier_array = crop_to_shape(fourier_array, downsampled_shape)
if get_real:
return ifftn(jnp.fft.ifftshift(cropped_fourier_array))
else:
Expand Down
8 changes: 4 additions & 4 deletions src/cryojax/image/_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def ifftn(
ift :
Inverse fourier transform.
"""
ift = jnp.fft.fftshift(jnp.fft.ifftn(ft, s=s, axes=axes, norm="ortho"), axes=axes)
ift = jnp.fft.fftshift(jnp.fft.ifftn(ft, s=s, axes=axes), axes=axes)

return ift

Expand All @@ -49,7 +49,7 @@ def fftn(
ft :
Fourier transform of array.
"""
ft = jnp.fft.fftn(jnp.fft.ifftshift(ift, axes=axes), s=s, axes=axes, norm="ortho")
ft = jnp.fft.fftn(jnp.fft.ifftshift(ift, axes=axes), s=s, axes=axes)

return ft

Expand All @@ -72,7 +72,7 @@ def irfftn(
ift :
Inverse fourier transform.
"""
ift = jnp.fft.fftshift(jnp.fft.irfftn(ft, s=s, axes=axes, norm="ortho"), axes=axes)
ift = jnp.fft.fftshift(jnp.fft.irfftn(ft, s=s, axes=axes), axes=axes)

return ift

Expand All @@ -94,6 +94,6 @@ def rfftn(
ft :
Fourier transform of array.
"""
ft = jnp.fft.rfftn(jnp.fft.ifftshift(ift, axes=axes), s=s, axes=axes, norm="ortho")
ft = jnp.fft.rfftn(jnp.fft.ifftshift(ift, axes=axes), s=s, axes=axes)

return ft
27 changes: 24 additions & 3 deletions src/cryojax/image/_normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
Image normalization routines.
"""

import math
from typing import Optional

import jax.numpy as jnp
from jaxtyping import Array, Float, Inexact

Expand All @@ -11,6 +14,7 @@ def normalize_image(
*,
is_real: bool = True,
half_space: bool = True,
shape_in_real_space: Optional[tuple[int, int]] = None,
) -> Inexact[Array, "y_dim x_dim"]:
"""
Normalize so that the image is mean 0
Expand All @@ -22,6 +26,7 @@ def normalize_image(
0.0,
is_real=is_real,
half_space=half_space,
shape_in_real_space=shape_in_real_space,
)


Expand All @@ -32,6 +37,7 @@ def rescale_image(
*,
is_real: bool = True,
half_space: bool = True,
shape_in_real_space: Optional[tuple[int, int]] = None,
) -> Inexact[Array, "y_dim x_dim"]:
"""Normalize so that the image is mean mu
and standard deviation N in real space.
Expand Down Expand Up @@ -64,9 +70,24 @@ def rescale_image(
rescaled_image = std * normalized_image + mean
else:
N1, N2 = image.shape
n_pixels, n_modes = N1 * (2 * N2 - 1) if half_space else N1 * N2, N1 * N2
n_pixels = (
(
N1 * (2 * N2 - 1)
if shape_in_real_space is None
else math.prod(shape_in_real_space)
)
if half_space
else N1 * N2
)
image_with_zero_mean = image.at[0, 0].set(0.0)
image_std = jnp.linalg.norm(image_with_zero_mean) / jnp.sqrt(n_modes)
image_std = (
jnp.sqrt(
jnp.sum(jnp.abs(image_with_zero_mean[:, 0]) ** 2)
+ 2 * jnp.sum(jnp.abs(image_with_zero_mean[:, 1:]) ** 2)
)
if half_space
else jnp.linalg.norm(image_with_zero_mean)
) / n_pixels
normalized_image = image_with_zero_mean / image_std
rescaled_image = (normalized_image * std).at[0, 0].set(mean * jnp.sqrt(n_pixels))
rescaled_image = (normalized_image * std).at[0, 0].set(mean * n_pixels)
return rescaled_image
16 changes: 11 additions & 5 deletions src/cryojax/inference/distributions/_gaussian_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,11 @@ def compute_noise(
]
):
pipeline = self.imaging_pipeline
n_pixels = pipeline.instrument_config.padded_n_pixels
freqs = pipeline.instrument_config.padded_frequency_grid_in_angstroms
# Compute the zero mean variance and scale up to be independent of the number of
# pixels
std = jnp.sqrt(self.variance_function(freqs))
std = jnp.sqrt(n_pixels * self.variance_function(freqs))
noise = pipeline.postprocess(
std
* jr.normal(rng_key, shape=freqs.shape[0:-1])
Expand Down Expand Up @@ -132,9 +133,10 @@ def log_likelihood(
- `observed` : The observed data in fourier space.
"""
pipeline = self.imaging_pipeline
n_pixels = pipeline.instrument_config.n_pixels
freqs = pipeline.instrument_config.frequency_grid_in_angstroms
# Compute the variance and scale up to be independent of the number of pixels
variance = self.variance_function(freqs)
variance = n_pixels * self.variance_function(freqs)
# Create simulated data
simulated = self.compute_signal(get_real=False)
# Compute residuals
Expand All @@ -147,9 +149,13 @@ def log_likelihood(
)
# Compute log-likelihood, throwing away the zero mode. Need to take care
# to compute the loss function in fourier space for a real-valued function.
log_likelihood = -1.0 * (
jnp.sum(log_likelihood_per_mode[1:, 0])
+ 2 * jnp.sum(log_likelihood_per_mode[:, 1:])
log_likelihood = (
-1.0
* (
jnp.sum(log_likelihood_per_mode[1:, 0])
+ 2 * jnp.sum(log_likelihood_per_mode[:, 1:])
)
/ n_pixels
)

return log_likelihood
2 changes: 1 addition & 1 deletion src/cryojax/simulator/_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def _compute_expected_events_or_detector_readout(
electrons_per_image = N_pix * electrons_per_pixel
# Normalize the squared wavefunction to a set of probabilities
fourier_squared_wavefunction_at_detector_plane /= (
fourier_squared_wavefunction_at_detector_plane[0, 0] * jnp.sqrt(N_pix)
fourier_squared_wavefunction_at_detector_plane[0, 0]
)
# Compute the noiseless signal by applying the DQE to the squared wavefunction
fourier_signal = fourier_squared_wavefunction_at_detector_plane * jnp.sqrt(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -289,5 +289,4 @@ def _extract_surface_from_voxel_grid(
# Set last line of frequencies to zero if image dimension is even
if N % 2 == 0:
projection = projection.at[:, -1].set(0.0 + 0.0j).at[N // 2, :].set(0.0 + 0.0j)
# Re-scale projection to account for "ortho" FFT normalization convention
return projection * N ** (1 / 2)
return projection
3 changes: 1 addition & 2 deletions src/cryojax/simulator/_potential_integrator/nufft_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,5 +129,4 @@ def _project_with_nufft(weights, coordinate_list, shape, eps=1e-6):
projection = projection.at[:, -1].set(0.0 + 0.0j)
if M1 % 2 == 0:
projection = projection.at[M1 // 2, :].set(0.0 + 0.0j)
# Return projection in "ortho" FFT normalization convention
return projection / jnp.sqrt(M1 * M2)
return projection
3 changes: 2 additions & 1 deletion src/cryojax/simulator/_solvent.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,11 @@ def sample_fourier_phase_shifts_from_ice(
Array, "{instrument_config.padded_y_dim} {instrument_config.padded_x_dim//2+1}"
]:
"""Sample a realization of the ice phase shifts as colored gaussian noise."""
n_pixels = instrument_config.padded_n_pixels
frequency_grid_in_angstroms = instrument_config.padded_frequency_grid_in_angstroms
# Compute standard deviation, scaling up by the variance by the number
# of pixels to make the realization independent pixel-independent in real-space.
std = jnp.sqrt(self.variance_function(frequency_grid_in_angstroms))
std = jnp.sqrt(n_pixels * self.variance_function(frequency_grid_in_angstroms))
ice_integrated_potential_at_exit_plane = std * jr.normal(
key,
shape=frequency_grid_in_angstroms.shape[0:-1],
Expand Down
4 changes: 2 additions & 2 deletions tests/test_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,12 @@ def test_gaussian_limit():
np.testing.assert_allclose(
irfftn(
jnp.abs(fourier_gaussian_detector_readout) ** 2
/ (jnp.sqrt(n_pixels) * electrons_per_pixel**2),
/ (n_pixels * electrons_per_pixel**2),
s=config.padded_shape,
),
irfftn(
jnp.abs(fourier_poisson_detector_readout) ** 2
/ (jnp.sqrt(n_pixels) * electrons_per_pixel**2),
/ (n_pixels * electrons_per_pixel**2),
s=config.padded_shape,
),
rtol=1e-2,
Expand Down
3 changes: 2 additions & 1 deletion tests/test_normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@ def test_fourier_vs_real_normalized_image(noisy_model):
normalize_image(
noisy_model.render(get_real=False),
is_real=False,
shape_in_real_space=im1.shape,
),
s=noisy_model.instrument_config.shape,
) # type: ignore
for im in [im1, im2]:
np.testing.assert_allclose(jnp.std(im), jnp.asarray(1.0), atol=1e-2)
np.testing.assert_allclose(jnp.std(im), jnp.asarray(1.0), rtol=1e-3)
np.testing.assert_allclose(jnp.mean(im), jnp.asarray(0.0), atol=1e-8)
Loading