Skip to content

Commit

Permalink
Save result from jax.local_device_count()
Browse files Browse the repository at this point in the history
Identical issue to google/nerfies#47
  • Loading branch information
JamesPerlman committed Feb 17, 2022
1 parent a371bc2 commit 1f87f3f
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion hypernerf/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ def points_bounding_size(points):
def shard(xs, device_count=None):
"""Split data into shards for multiple devices along the first dimension."""
if device_count is None:
jax.local_device_count()
device_count = jax.local_device_count()
return jax.tree_map(lambda x: x.reshape((device_count, -1) + x.shape[1:]), xs)


Expand Down

0 comments on commit 1f87f3f

Please sign in to comment.