This library is a modular framework for simulating forward models of cryo-electron microscopy images. It is designed with 2D template matching analysis in mind, but it can be used generally. cryojax is, of course, built on jax. It also uses equinox for model building, so equinox functionality is supported in cryojax.
The core of this package is its ability to simulate cryo-EM images. Starting with a 3D electron density map, one can simulate a scattering process onto the imaging plane with modulation by the instrument optics. Images are then sampled from models of the noise or the corresponding log-likelihood is computed.
These models can be fed into standard sampling, optimization, and model building libraries in jax, such as blackjax, optax, or numpyro. The jax ecosystem is rich and growing fast!
Installing cryojax is simple. To start, I recommend creating a new virtual environment. For example, you could do this with conda.
conda create -n cryojax -c conda-forge python=3.10
Note that python>=3.10 is required due to recent features in dataclasses. Now, install JAX with either CPU or GPU support.
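After installing JAX, it can be worth confirming which backend it picked up before moving on. The snippet below is a small sanity check that uses only jax itself; it is not part of cryojax.

```python
import jax
import jax.numpy as jnp

# List the devices JAX can see; a CPU-only install reports CPU devices.
devices = jax.devices()
print(devices)

# Run a trivial computation to confirm the backend works end to end.
total = jnp.arange(4.0).sum()
print(total)
```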
Finally, install cryojax. For now, only a source build is supported.
git clone https://github.com/mjo22/cryojax
cd cryojax
python -m pip install .
This will install the remaining dependencies, such as equinox for jax-friendly dataclasses, jaxlie for coordinate rotations and translations, mrcfile for I/O, and dataclasses-json for serialization.
The jax-finufft package is an optional dependency used for non-uniform fast Fourier transforms. These are included as an option for computing image projections. If you want to use this option, we recommend first following the jax_finufft installation instructions and then installing cryojax.
Please note that this library is currently experimental and the API is subject to change! The following is a basic workflow to generate an image with a gaussian white noise model.
First, instantiate the image formation method ("scattering") and its respective representation of an electron density ("specimen").
import jax
import jax.numpy as jnp
import cryojax.simulator as cs
template = "example.mrc"
scattering = cs.FourierSliceScattering(shape=(320, 320))
density = cs.ElectronGrid.from_file(template)
Here, template is a 3D electron density map in MRC format. This could be taken from the EMDB, or rasterized from a PDB. cisTEM provides an excellent rasterization tool in its image simulation program. In the above example, a voxel electron density in Fourier space is loaded and a scattering method based on the Fourier slice theorem is initialized. We can now instantiate the biological Specimen.
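The Fourier slice theorem mentioned above can be checked numerically in a few lines. The sketch below uses NumPy and a random toy volume (an assumption standing in for a real density map), and it only covers the unrotated case where the projection is taken along the z axis. It illustrates the theorem itself, not how FourierSliceScattering is implemented.

```python
import numpy as np

rng = np.random.default_rng(0)
# Toy stand-in for a density map, indexed as (z, y, x).
volume = rng.normal(size=(16, 16, 16))

# Real-space route: integrate the density along the z axis.
projection = volume.sum(axis=0)

# Fourier-space route: take the 3D transform and extract the k_z = 0 plane.
central_slice = np.fft.fftn(volume)[0, :, :]

# The 2D transform of the projection equals the central slice,
# up to floating-point error.
slice_matches = np.allclose(np.fft.fft2(projection), central_slice)
print(slice_matches)
```

For a rotated pose the slice is tilted relative to the voxel grid and has to be interpolated, which this sketch leaves out.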
specimen = cs.Specimen(density, resolution=1.1)
This is a container for the parameters and metadata stored in the electron density, along with additional parameters such as the rasterization resolution.
Next, the model is configured for a given realization of the specimen. Here, Pose, Optics, and Detector models and their respective parameters are initialized. These are stored in the PipelineState container.
key = jax.random.PRNGKey(seed=0)
pose = cs.EulerPose(view_phi=0.0, view_theta=0.0, view_psi=0.0)
optics = cs.CTFOptics(defocus_u=10000.0, defocus_v=9800.0, defocus_angle=10.0)
detector = cs.GaussianDetector(key=key, pixel_size=1.1, variance=cs.Constant(1.0))
state = cs.PipelineState(pose=pose, optics=optics, detector=detector)
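To give a feel for what the optics parameters control, here is a minimal NumPy sketch of an astigmatic contrast transfer function. It is not cryojax's CTFOptics: the function name ctf_2d, the default wavelength (about 0.0197 angstroms, i.e. 300 kV), spherical aberration (2.7 mm), and amplitude contrast (0.07) are all assumptions, as is reading the defocus angle in degrees. Sign conventions for the CTF also differ between packages, so treat this as one common convention only.

```python
import numpy as np

def ctf_2d(shape, pixel_size, defocus_u, defocus_v, defocus_angle_deg,
           wavelength=0.0197, spherical_aberration=2.7e7, amplitude_contrast=0.07):
    """Astigmatic CTF on a 2D frequency grid. All lengths are in angstroms."""
    ky = np.fft.fftfreq(shape[0], d=pixel_size)[:, None]
    kx = np.fft.fftfreq(shape[1], d=pixel_size)[None, :]
    k2 = kx**2 + ky**2
    theta = np.arctan2(ky, kx)
    # The defocus varies with azimuth between defocus_u and defocus_v.
    angle = np.deg2rad(defocus_angle_deg)
    defocus = 0.5 * (defocus_u + defocus_v
                     + (defocus_u - defocus_v) * np.cos(2.0 * (theta - angle)))
    # Phase aberration from defocus and spherical aberration.
    chi = (np.pi * wavelength * defocus * k2
           - 0.5 * np.pi * spherical_aberration * wavelength**3 * k2**2)
    w = amplitude_contrast
    return -(np.sqrt(1.0 - w**2) * np.sin(chi) + w * np.cos(chi))

ctf = ctf_2d((64, 64), pixel_size=1.1, defocus_u=10000.0,
             defocus_v=9800.0, defocus_angle_deg=10.0)
```

Multiplying the Fourier transform of a projection by such an array is the usual way the optics modulate the image.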
Then, an ImagePipeline model is chosen. Here, we choose GaussianImage.
model = cs.GaussianImage(scattering=scattering, specimen=specimen, state=state)
image = model()
This computes an image using the noise model of the detector (under the hood, model.sample() is called). One can also compute an image without the stochastic part of the model.
image = model.render()
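The difference between the sampled and rendered images can be pictured as adding a noise draw to a noiseless image. The sketch below does this for Gaussian white noise in plain jax; the helper sample_white_noise_image and the all-zero placeholder image are hypothetical and only stand in for what the detector model does internally.

```python
import jax
import jax.numpy as jnp

def sample_white_noise_image(key, clean_image, variance):
    """Add zero-mean Gaussian white noise with the given variance to a noiseless image."""
    noise = jnp.sqrt(variance) * jax.random.normal(key, clean_image.shape)
    return clean_image + noise

key = jax.random.PRNGKey(0)
# Placeholder for a noiseless image such as the output of model.render().
clean_image = jnp.zeros((320, 320))
noisy_image = sample_white_noise_image(key, clean_image, variance=1.0)
```

Because jax random number generation is explicit, reusing the same key reproduces the same noise realization.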
Imaging models also accept a series of Filters and Masks. For example, one could add a LowpassFilter, WhiteningFilter, and a CircularMask.
filters = [
    cs.LowpassFilter(scattering.padded_shape, cutoff=1.0),  # Cut off modes above the Nyquist frequency
    cs.WhiteningFilter(scattering.padded_shape, micrograph=micrograph),
]
masks = [cs.CircularMask(scattering.shape, radius=1.0)]  # Cut off pixels beyond a radius equal to half the image size
model = cs.GaussianImage(scattering=scattering, specimen=specimen, state=state, filters=filters, masks=masks)
image = model()
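As a rough picture of what a low-pass filter and a circular mask do, the NumPy sketch below builds hard-edged versions of both. The functions lowpass_filter and circular_mask are hypothetical helpers, not the cryojax classes, which may use soft edges or a different parameterization. The meaning of cutoff (a fraction of the Nyquist frequency) and radius (a fraction of half the image size) follows the comments in the example above.

```python
import numpy as np

def lowpass_filter(shape, cutoff=1.0):
    """Hard low-pass filter in Fourier space; cutoff is a fraction of Nyquist."""
    ky = np.fft.fftfreq(shape[0])[:, None]  # cycles per pixel, Nyquist is 0.5
    kx = np.fft.fftfreq(shape[1])[None, :]
    k = np.sqrt(kx**2 + ky**2)
    return (k <= 0.5 * cutoff).astype(float)

def circular_mask(shape, radius=1.0):
    """Hard circular mask in real space; radius is a fraction of half the image size."""
    y = np.arange(shape[0]) - shape[0] // 2
    x = np.arange(shape[1]) - shape[1] // 2
    r = np.sqrt(x[None, :] ** 2 + y[:, None] ** 2)
    return (r <= radius * min(shape) / 2).astype(float)

rng = np.random.default_rng(0)
image = rng.normal(size=(64, 64))
# Filters multiply the image in Fourier space, masks multiply it in real space.
filtered = np.fft.ifft2(np.fft.fft2(image) * lowpass_filter(image.shape, cutoff=0.5)).real
masked = filtered * circular_mask(image.shape, radius=1.0)
```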
If a GaussianImage is initialized with the field observed, the model will instead compute the log likelihood.
model = cs.GaussianImage(scattering=scattering, specimen=specimen, state=state, observed=observed)
log_likelihood = model()
Under the hood, this calls model.log_probability(). Note that the user may need to do preprocessing of observed, such as applying the relevant Filters and Masks.
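For reference, the log likelihood of an observed image under independent Gaussian pixel noise can be written out directly. The function below is a minimal sketch of that textbook formula in jax. The name gaussian_log_likelihood is hypothetical, and cryojax's log_probability may differ in details such as normalization constants or the space in which the residual is evaluated.

```python
import jax.numpy as jnp

def gaussian_log_likelihood(observed, simulated, variance):
    """Log-density of `observed` under i.i.d. Gaussian noise around `simulated`."""
    residual = observed - simulated
    n_pixels = residual.size
    return (-0.5 * jnp.sum(residual**2) / variance
            - 0.5 * n_pixels * jnp.log(2.0 * jnp.pi * variance))

simulated = jnp.ones((8, 8))
observed = simulated + 0.1
log_likelihood = gaussian_log_likelihood(observed, simulated, variance=1.0)
```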
Additional components can be plugged into the ImagePipeline model's PipelineState. For example, Ice and electron beam Exposure models are supported: GaussianIce models the ice as gaussian noise, and UniformExposure multiplies the image by a scale factor. Imaging models from different stages of the pipeline are also implemented. ScatteringImage computes images solely with the scattering model, while OpticsImage uses a scattering and optics model. DetectorImage turns this into a detector readout, while GaussianImage adds the ability to evaluate a gaussian likelihood.
For these more advanced examples, see the tutorials section of the repository. In general, cryojax is designed to be very extensible and new models can easily be implemented.
In jax, we ultimately want to build a loss function and apply functional transformations to it. Assuming we have already globally configured our model components at our desired initial state, the code below creates a loss function at an updated set of parameters. First, we must update the model.
import equinox as eqx

@jax.jit
def update_model(model: cs.GaussianImage, params: dict[str, jax.Array]) -> cs.GaussianImage:
    """
    Update the model with equinox.tree_at (https://docs.kidger.site/equinox/api/manipulation/#equinox.tree_at).
    """
    # Select the leaves to replace, then swap in the new parameter values.
    where = lambda model: (model.state.pose.view_phi, model.state.optics.defocus_u, model.state.detector.pixel_size)
    updated_model = eqx.tree_at(where, model, (params["view_phi"], params["defocus_u"], params["pixel_size"]))
    return updated_model
We can now create the loss and differentiate it with respect to the parameters.
from functools import partial

@jax.jit
@partial(jax.value_and_grad, argnums=1)
def loss(model: cs.GaussianImage, params: dict[str, jax.Array]) -> jax.Array:
    model = update_model(model, params)
    return model.log_probability()
Finally, we can evaluate an updated set of parameters.
params = dict(view_phi=jnp.asarray(jnp.pi), defocus_u=jnp.asarray(9000.0), pixel_size=jnp.asarray(1.30))
log_likelihood, grad = loss(model, params)
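The same pattern (a dictionary of parameters, a jitted value_and_grad, and an update step) can be tried without cryojax installed. The sketch below swaps the imaging model for a toy Gaussian blob, so render, the parameter names x0, y0, and amplitude, and the learning rate are all made up for illustration. Only the jax calls mirror the example above.

```python
from functools import partial

import jax
import jax.numpy as jnp

def render(params, coords):
    """Toy forward model: an isotropic Gaussian blob (not a cryo-EM simulation)."""
    y, x = coords
    r2 = (x - params["x0"]) ** 2 + (y - params["y0"]) ** 2
    return params["amplitude"] * jnp.exp(-0.5 * r2 / 4.0)

@jax.jit
@partial(jax.value_and_grad, argnums=0)
def loss(params, coords, observed):
    """Negative Gaussian log likelihood with unit variance, up to a constant."""
    residual = observed - render(params, coords)
    return 0.5 * jnp.sum(residual**2)

coords = jnp.meshgrid(jnp.arange(32.0), jnp.arange(32.0), indexing="ij")
true_params = dict(x0=jnp.asarray(16.0), y0=jnp.asarray(14.0), amplitude=jnp.asarray(2.0))
observed = render(true_params, coords)

# Evaluate the loss and its gradient at an initial guess.
params = dict(x0=jnp.asarray(15.0), y0=jnp.asarray(15.0), amplitude=jnp.asarray(1.0))
value, grad = loss(params, coords, observed)

# One plain gradient-descent step; grad has the same dictionary structure as params.
learning_rate = 1e-3
new_params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grad)
new_value, _ = loss(new_params, coords, observed)
```

In practice the hand-written update step would be replaced by an optimizer or sampler from a library such as optax or blackjax.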
To summarize, this example creates a loss function at an updated set of Pose, Optics, and Detector parameters. Note that the PipelineState contains all of the model parameters in this example. In general, any cryojax Module may contain model parameters. One gotcha is that the ScatteringConfig, Filters, and Masks all do computation upon initialization, so they should not be explicitly instantiated in the loss function evaluation. Another gotcha is that if the model is not passed as an argument to the loss, there may be long compilation times because the electron density will be treated as static. In exchange, treating the density as static may result in slight speedups.
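The second gotcha comes down to how jax.jit treats arrays that a function closes over rather than receives as arguments. The sketch below contrasts the two cases with a toy array standing in for the electron density; the function names are hypothetical.

```python
import jax
import jax.numpy as jnp

# Stand-in for a large electron density array.
density = jnp.ones((64, 64))

@jax.jit
def loss_with_argument(density, scale):
    # `density` is a traced input, so the compiled code is reused
    # for any array of the same shape and dtype.
    return jnp.sum((scale * density) ** 2)

@jax.jit
def loss_with_closure(scale):
    # `density` is captured from the enclosing scope and embedded
    # in the compiled program as a constant at trace time.
    return jnp.sum((scale * density) ** 2)

a = loss_with_argument(density, 2.0)
b = loss_with_closure(2.0)
```

Both versions return the same value. The difference is in compilation: a large embedded constant can make compilation slower, and the closure version will not see later changes to the captured array without retracing.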
In general, there are many ways to write loss functions. See the equinox documentation for more use cases.
- Imaging models in cryojax support jax functional transformations, such as automatic differentiation with grad, parallelization with vmap and pmap, and just-in-time compilation with jit. Models also support GPU/TPU acceleration.
- cryojax.Modules, including ImagePipeline models, are JSON serializable thanks to the package dataclasses-json. The method Module.dumps serializes the object as a JSON string, and Module.loads instantiates it from the string. For example, write a model to disk with model.dump("model.json") and instantiate it with cs.GaussianImage.load("model.json").
- A cryojax.Module is just an equinox.Module with added serialization functionality. Therefore, the entire equinox ecosystem is available for usage!
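As a small illustration of the functional transformations listed above, the sketch below batches a toy function over one parameter with vmap and compiles the result with jit. The function toy_image is a made-up stand-in for a forward model and does not use cryojax.

```python
import jax
import jax.numpy as jnp

def toy_image(defocus):
    """Toy stand-in for a forward model: a 1D pattern that depends on one parameter."""
    k = jnp.linspace(0.0, 0.5, 64)
    return jnp.sin(1e-3 * defocus * k**2)

# Evaluate the model for a batch of parameter values in one compiled call.
defocus_values = jnp.linspace(8000.0, 12000.0, 5)
batched = jax.jit(jax.vmap(toy_image))(defocus_values)
print(batched.shape)
```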
- cisTEM: A software to process cryo-EM images of macromolecular complexes and obtain high-resolution 3D reconstructions from them. The recent experimental release of cisTEM has implemented a successful 2DTM program.
- BioEM: Bayesian inference of Electron Microscopy. This codebase calculates the posterior probability of a structural model given multiple experimental EM images.