Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
…g/v2.15.0)

PiperOrigin-RevId: 584135686
  • Loading branch information
romanngg committed Nov 21, 2023
1 parent aa3620d commit 429dc13
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 5 deletions.
6 changes: 2 additions & 4 deletions neural_tangents/experimental/empirical_tf/empirical.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,8 @@
import tf2jax


# TODO(romann): update to PolymorphicFunction with tf 2.15 release
def empirical_ntk_fn_tf(
f: Union[tf.Module, tf.types.experimental.GenericFunction],
f: Union[tf.Module, tf.types.experimental.PolymorphicFunction],
trace_axes: Axes = (-1,),
diagonal_axes: Axes = (),
vmap_axes: VMapAxes = None,
Expand Down Expand Up @@ -244,8 +243,7 @@ def empirical_ntk_fn_tf(
if isinstance(f, tf.Module):
apply_fn, _ = get_apply_fn_and_params(f)

# TODO(romann): update to PolymorphicFunction with tf 2.15 release
elif isinstance(f, tf.types.experimental.GenericFunction):
elif isinstance(f, tf.types.experimental.PolymorphicFunction):
apply_fn = tf2jax.convert_functional(f, *f.input_signature)

else:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
INSTALL_REQUIRES = [
'jax>=0.4.14',
'frozendict>=2.3.8',
'tensorflow>=2.14.0',
'tensorflow>=2.15.0',
'tf2jax>=0.3.5',
]

Expand Down

0 comments on commit 429dc13

Please sign in to comment.