Skip to content

Commit

Permalink
Updated jax.config import
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 574912397
  • Loading branch information
superbobry authored and romanngg committed Nov 21, 2023
1 parent dce935b commit 109796e
Show file tree
Hide file tree
Showing 12 changed files with 12 additions and 12 deletions.
2 changes: 1 addition & 1 deletion tests/batching_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@

from absl.testing import absltest
from jax import jit
from jax import config
from jax import random
from jax.config import config
import jax.numpy as jnp
from jax.tree_util import tree_map
import neural_tangents as nt
Expand Down
2 changes: 1 addition & 1 deletion tests/elementwise_numerical_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from absl.testing import absltest

from jax.config import config
from jax import config
from examples import elementwise_numerical
from tests import test_utils

Expand Down
2 changes: 1 addition & 1 deletion tests/elementwise_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from absl.testing import absltest

from jax.config import config
from jax import config
from examples import elementwise
from tests import test_utils

Expand Down
2 changes: 1 addition & 1 deletion tests/empirical_ntk_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from absl.testing import absltest

from jax.config import config
from jax import config
from examples import empirical_ntk
from tests import test_utils

Expand Down
2 changes: 1 addition & 1 deletion tests/empirical_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from jax import random
from jax import remat
from jax import tree_map
from jax.config import config
from jax import config
import jax.numpy as jnp
from jax.tree_util import tree_reduce
import neural_tangents as nt
Expand Down
2 changes: 1 addition & 1 deletion tests/function_space_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from absl.testing import absltest

from jax.config import config
from jax import config
from examples import function_space
from tests import test_utils

Expand Down
2 changes: 1 addition & 1 deletion tests/imdb_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
"""Tests for `examples/imdb.py`."""

from absl.testing import absltest
from jax.config import config
from jax import config
from examples import imdb
from tests import test_utils

Expand Down
2 changes: 1 addition & 1 deletion tests/infinite_fcn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from absl.testing import absltest

from jax.config import config
from jax import config
from examples import infinite_fcn
from tests import test_utils

Expand Down
2 changes: 1 addition & 1 deletion tests/monte_carlo_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@

from absl.testing import absltest
import jax
from jax import config
from jax import random
from jax.config import config
import jax.numpy as jnp
import neural_tangents as nt
from neural_tangents import stax
Expand Down
2 changes: 1 addition & 1 deletion tests/predict_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from jax import jit
from jax import random
from jax import vmap
from jax.config import config
from jax import config
from jax.example_libraries import optimizers
from jax.flatten_util import ravel_pytree
import jax.numpy as jnp
Expand Down
2 changes: 1 addition & 1 deletion tests/rules_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from absl.testing import absltest
import jax
from jax import lax
from jax.config import config
from jax import config
from jax.core import Primitive
from jax.core import ShapedArray
from jax.interpreters import ad
Expand Down
2 changes: 1 addition & 1 deletion tests/weight_space_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from absl.testing import absltest

from jax.config import config
from jax import config
from examples import weight_space
from tests import test_utils

Expand Down

0 comments on commit 109796e

Please sign in to comment.