implement backend-agnostic jacobian trace for Flow Matching
Showing 3 changed files with 120 additions and 0 deletions.
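For reference, the identity the new code relies on: the trace of the Jacobian $J$ of the Flow Matching velocity field enters the continuous change-of-variables formula for the log-density, and Hutchinson's estimator approximates it with random probes $v_s$ satisfying $\mathbb{E}[v_s] = 0$ and $\mathbb{E}[v_s v_s^\top] = I$:

$$\operatorname{tr}(J) \;=\; \mathbb{E}\big[v^\top J v\big] \;\approx\; \frac{1}{S} \sum_{s=1}^{S} v_s^\top J v_s,$$

where each $v_s^\top J$ is a single vector-Jacobian product, so the cost per probe is one backward pass regardless of the dimension $d$.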
@@ -0,0 +1 @@
from .jacobian_trace import jacobian_trace

@@ -0,0 +1,117 @@
import keras
import numpy as np

from bayesflow.types import Tensor


def jacobian_trace(f: callable, x: Tensor, samples: int = 1) -> tuple[Tensor, Tensor]:
    """
    Returns the output of f at x, along with an unbiased estimate of the trace
    of the Jacobian of f, using Hutchinson's estimator.

    :param f: The function to be differentiated.
        Must take x as its only argument and return a single output Tensor.
    :param x: Tensor of shape (n, d)
        The input tensor to f.
    :param samples: The number of random samples to use for the estimate.
        If this meets or exceeds the dimensionality d of x, or if you pass
        None, the trace is instead computed exactly, with one vector-Jacobian
        product per dimension.
        Default: 1
    :return: A tuple (fx, trace), where fx = f(x) and trace is a Tensor of
        shape (n,) holding an unbiased estimate of the trace of the Jacobian
        of f for each of the n batch elements.
    """
    batch_size, dims = keras.ops.shape(x)

    match keras.backend.backend():
        case "jax":
            import jax

            # jax.vjp evaluates f(x) once and returns a pullback mapping a
            # cotangent v to the vector-Jacobian product v^T J
            fx, vjp_fn = jax.vjp(f, x)
            vjp_fn = jax.jit(vjp_fn)

            trace = keras.ops.zeros((batch_size,), dtype=x.dtype)

            # TODO: can we use jax.vmap to avoid the for loop?

            if samples is None or dims <= samples:
                # exact: e_dim^T J picks out row dim of J, so its entry at
                # index dim is the diagonal element J[dim, dim]
                for dim in range(dims):
                    projector = keras.ops.zeros((batch_size, dims), dtype=x.dtype)
                    projector = projector.at[:, dim].set(1.0)

                    vjp = vjp_fn(projector)[0]

                    trace += vjp[:, dim]
            else:
                # estimate
                for _ in range(samples):
                    projector = keras.random.normal((batch_size, dims), dtype=x.dtype)

                    vjp = vjp_fn(projector)[0]

                    # average over probes to keep the estimate unbiased
                    trace += keras.ops.sum(vjp * projector, axis=1) / samples
case "tensorflow": | ||
import tensorflow as tf | ||
|
||
with tf.GradientTape(persistent=True) as tape: | ||
tape.watch(x) | ||
fx = f(x) | ||
|
||
trace = keras.ops.zeros((batch_size,)) | ||
|
||
# TODO: can we use tf.gradients to avoid the for loop? | ||
|
||
if samples is None or dims <= samples: | ||
# exact | ||
for dim in range(dims): | ||
projector = np.zeros((batch_size, dims), dtype=keras.backend.standardize_dtype(x.dtype)) | ||
projector[:, dim] = 1.0 | ||
projector = keras.ops.convert_to_tensor(projector) | ||
|
||
vjp = tape.gradient(fx, x, projector) | ||
|
||
trace += vjp[:, dim] | ||
else: | ||
# estimate | ||
for _ in range(samples): | ||
projector = keras.random.normal((batch_size, dims), dtype=x.dtype) | ||
|
||
vjp = tape.gradient(fx, x, projector) | ||
|
||
trace += keras.ops.sum(vjp * projector, axis=1) / samples | ||
case "torch": | ||
import torch | ||
|
||
with torch.enable_grad(): | ||
x.requires_grad = True | ||
fx = f(x) | ||
|
||
trace = keras.ops.zeros(keras.ops.shape(x)[0]) | ||
|
||
# TODO: can we use is_grads_batched to avoid the for loop? | ||
|
||
if samples is None or dims <= samples: | ||
# exact | ||
for dim in range(dims): | ||
projector = keras.ops.zeros((batch_size, dims), dtype=x.dtype) | ||
projector[:, dim] = 1.0 | ||
|
||
vjp = torch.autograd.grad(fx, x, projector, retain_graph=True)[0] | ||
|
||
trace += vjp[:, dim] | ||
else: | ||
# estimate | ||
for _ in range(samples): | ||
projector = keras.random.normal((batch_size, dims), dtype=x.dtype) | ||
|
||
vjp = torch.autograd.grad(fx, x, projector, retain_graph=True)[0] | ||
|
||
trace += keras.ops.sum(vjp * projector, axis=1) / samples | ||
        case other:
            raise NotImplementedError(f"Jacobian trace computation is currently not supported for backend '{other}'.")

    return fx, trace
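
As a quick sanity check of the new helper, one can compare it against a map whose Jacobian trace is known in closed form. The sketch below is not part of the commit; the import path is hypothetical (file paths are not shown in the diff), and it uses a linear map x @ A, whose per-sample Jacobian is A^T, so the exact trace equals tr(A):

import keras
import numpy as np

from jacobian_trace import jacobian_trace  # hypothetical import path

rng = np.random.default_rng(0)
A = keras.ops.convert_to_tensor(rng.standard_normal((4, 4)).astype("float32"))


def f(x):
    # per-sample Jacobian of x @ A is A^T, so tr(J) = tr(A)
    return keras.ops.matmul(x, A)


x = keras.random.normal((8, 4))

# samples=None forces the exact path: one VJP per input dimension
fx, exact = jacobian_trace(f, x, samples=None)

# samples < d gives the Hutchinson estimate; more probes reduce its variance
fx, estimate = jacobian_trace(f, x, samples=2)

expected = np.trace(keras.ops.convert_to_numpy(A))
print(keras.ops.convert_to_numpy(exact) - expected)     # ~0 for every batch element
print(keras.ops.convert_to_numpy(estimate) - expected)  # noisy, zero in expectation

In the Flow Matching setting, f would be the learned velocity field at a fixed time, and the returned trace is integrated alongside the ODE to track the log-density.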