Skip to content

Commit

Permalink
Avoid deprecated ad.config & ad.source_info_util
Browse files Browse the repository at this point in the history
These were deprecated in JAX v0.4.19 and will be removed in a future release.

PiperOrigin-RevId: 606274588
  • Loading branch information
Jake VanderPlas authored and romanngg committed Feb 20, 2024
1 parent e91c80e commit e84d2ab
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions neural_tangents/_src/empirical.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@
from jax.core import Var

from jax.extend import linear_util as lu
from jax.extend import source_info_util

from jax.interpreters import ad
from jax.interpreters.ad import UndefinedPrimal
Expand Down Expand Up @@ -1863,7 +1864,7 @@ def read_cotangent(v: Var) -> Union[jnp.ndarray, Zero]:
map(functools.partial(_write_primal, primal_env), jaxpr.invars, primals_in)

ct_env: dict[Var, jnp.ndarray] = {}
ctx = ad.source_info_util.transform_name_stack('transpose')
ctx = source_info_util.transform_name_stack('transpose')
with ctx:
map(functools.partial(_write_cotangent, 'outvars', ct_env),
jaxpr.outvars, cotangents_in)
Expand Down Expand Up @@ -2174,10 +2175,10 @@ def _eqn_vjp_fn(
# Identity function
return cts_in,

name_stack = (ad.source_info_util.current_name_stack() +
name_stack = (source_info_util.current_name_stack() +
eqn.source_info.name_stack)
with ad.source_info_util.user_context(eqn.source_info.traceback,
name_stack=name_stack):
with source_info_util.user_context(eqn.source_info.traceback,
name_stack=name_stack):
if eqn.primitive.call_primitive or eqn.primitive.map_primitive:
cts_in_avals = [v.aval for v in eqn.outvars]
params = dict(eqn.params)
Expand Down Expand Up @@ -2263,7 +2264,7 @@ def _write_cotangent(
return

ct_env[v] = ad.add_tangents(ct_env[v], ct) if v in ct_env else ct
if ad.config.jax_enable_checks:
if jax.config.jax_enable_checks:
ct_aval = core.get_aval(ct_env[v])
joined_aval = core.lattice_join(
v.aval, ct_aval).strip_weak_type().strip_named_shape()
Expand Down

0 comments on commit e84d2ab

Please sign in to comment.