Utilities for neural signed distance fields in JAX.

niklasschmitz/sdf_jax

sdf_jax

Content

sdf_jax
    ├── discretize.py     # utils for dense 2D and 3D grid evaluation of a field
    ├── examples.py       # for debugging: simple analytical SDFs like the sphere
    ├── hash_encoding.py  # Multiresolution Hash Encoding
    └── util.py           # plotting utils for level-sets from marching cubes
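As a rough illustration of what examples.py and discretize.py are described as providing, here is a minimal sketch in plain JAX of an analytical sphere SDF evaluated on a dense 3D grid. The names `sphere_sdf` and `eval_on_grid` are hypothetical and are not sdf_jax's actual API. The grid bounds [-1, 1]^3 and the resolution are assumptions chosen for the example.

```python
import jax
import jax.numpy as jnp


def sphere_sdf(p, radius=0.5):
    """Hypothetical helper: signed distance from p (shape (3,)) to a sphere at the origin.

    Negative inside, zero on the surface, positive outside.
    """
    return jnp.linalg.norm(p) - radius


def eval_on_grid(f, resolution=32, lo=-1.0, hi=1.0):
    """Hypothetical helper: evaluate f: R^3 -> R on a dense resolution^3 grid over [lo, hi]^3."""
    ticks = jnp.linspace(lo, hi, resolution)
    X, Y, Z = jnp.meshgrid(ticks, ticks, ticks, indexing="ij")
    points = jnp.stack([X, Y, Z], axis=-1).reshape(-1, 3)  # (resolution**3, 3)
    values = jax.vmap(f)(points)                           # one field value per point
    return values.reshape(resolution, resolution, resolution)


grid = eval_on_grid(sphere_sdf, resolution=33)
print(grid.shape)  # (33, 33, 33)

# Debugging check: a true SDF has a unit-norm gradient away from its singularities.
grad_norm = jnp.linalg.norm(jax.grad(sphere_sdf)(jnp.array([0.3, -0.2, 0.4])))
print(grad_norm)  # approximately 1.0
```

The resulting grid can be passed to a marching-cubes routine, for example scikit-image's `skimage.measure.marching_cubes(grid, level=0.0)`, to extract the zero level set, which is the kind of level-set plotting util.py is described as supporting.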

The Multiresolution Hash Encoding in sdf_jax/hash_encoding.py implements the method described in:

Instant Neural Graphics Primitives with a Multiresolution Hash Encoding
Thomas Müller, Alex Evans, Christoph Schied, Alexander Keller
ACM Transactions on Graphics (SIGGRAPH), July 2022
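For intuition, here is a minimal sketch of the encoding as the paper describes it, for a single 3D point in [0, 1]^3. Each of the L levels uses a grid of resolution N_l = floor(N_min * b^l) with growth factor b = exp((ln N_max - ln N_min) / (L - 1)). The 2^3 corner vertices of the cell containing the point are hashed into that level's table of size T by XOR-ing the coordinates multiplied by the paper's primes, their feature vectors are trilinearly interpolated, and the results of all levels are concatenated. This is not the code in sdf_jax/hash_encoding.py: the names `spatial_hash` and `encode_point` are hypothetical, and, unlike the paper, the sketch hashes every level instead of switching to a dense 1:1 mapping at coarse levels where (N_l + 1)^d <= T.

```python
import itertools
import math

import jax
import jax.numpy as jnp
import numpy as np

# Per-dimension primes of the paper's spatial hash (pi_1 = 1).
PRIMES = jnp.asarray(np.array([1, 2654435761, 805459861], dtype=np.uint32))


def spatial_hash(coords, table_size):
    """Hypothetical helper: hash integer grid coordinates (shape (3,)) into [0, table_size)."""
    c = coords.astype(jnp.uint32) * PRIMES  # uint32 products wrap modulo 2**32
    return (c[0] ^ c[1] ^ c[2]) % table_size


def encode_point(x, tables, nmin, nmax):
    """Hypothetical sketch: multiresolution hash encoding of one point x in [0, 1]^3.

    tables has shape (L, T, F); the result is a vector of length L * F.
    """
    num_levels, table_size, num_features = tables.shape
    growth = math.exp((math.log(nmax) - math.log(nmin)) / (num_levels - 1))
    features = []
    for level in range(num_levels):
        resolution = math.floor(nmin * growth**level)  # N_l as a static Python int
        scaled = x * resolution
        base = jnp.floor(scaled)  # lower corner of the enclosing cell
        frac = scaled - base      # position inside the cell, in [0, 1)^3
        base = base.astype(jnp.int32)
        interpolated = jnp.zeros(num_features, dtype=tables.dtype)
        for corner in itertools.product((0, 1), repeat=3):
            offset = jnp.array(corner, dtype=jnp.int32)
            index = spatial_hash(base + offset, table_size)
            # Trilinear weight: frac on axes where the corner is 1, (1 - frac) otherwise
            weight = jnp.prod(jnp.where(offset == 1, frac, 1.0 - frac))
            interpolated = interpolated + weight * tables[level, index]
        features.append(interpolated)
    return jnp.concatenate(features)


key = jax.random.PRNGKey(0)
tables = jax.random.uniform(key, (16, 2**14, 2), minval=-1e-4, maxval=1e-4)
x = jnp.array([0.25, 0.5, 0.75])
y = encode_point(x, tables, nmin=16, nmax=512)
print(y.shape)  # (32,)
```

Because the trilinear weights of the eight corners sum to one, the encoding is a smooth function of the table entries, so gradients flow back into `tables` during training. To encode a batch, the function can be mapped over points, e.g. `jax.vmap(lambda p: encode_point(p, tables, 16, 512))(points)`.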

Usage

Below is an example of how to wrap the Hash Encoding inside a treex layer:

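```python
from sdf_jax import hash_encoding
import jax.numpy as jnp
import jax.random as jrandom
import treex as tx

class HashEmbedding(tx.Module):
    # Trainable hash tables: (levels, 2**hashmap_size_log2, features_per_entry)
    theta: jnp.ndarray = tx.Parameter.node()

    def __init__(
        self,
        levels: int = 16,             # number of resolution levels
        hashmap_size_log2: int = 14,  # log2 of the hash table size per level
        features_per_entry: int = 2,  # feature dimension stored per table entry
        nmin: int = 16,               # coarsest grid resolution
        nmax: int = 512,              # finest grid resolution
    ):
        self.levels = levels
        self.hashmap_size_log2 = hashmap_size_log2
        self.features_per_entry = features_per_entry
        self.nmin = nmin
        self.nmax = nmax

    def __call__(self, x):
        # x is a single (unbatched) point, e.g. of shape (3,)
        assert x.ndim == 1
        if self.initializing():
            # Create the tables on the first call inside .init(), using small
            # uniform values in [-1e-4, 1e-4]
            hashmap_size = 1 << self.hashmap_size_log2
            key = tx.next_key()
            self.theta = jrandom.uniform(
                key,
                (self.levels, hashmap_size, self.features_per_entry),
                minval=-0.0001,
                maxval=0.0001,
            )

        y = hash_encoding.encode(x, self.theta, self.nmin, self.nmax)
        # Flatten into a single feature vector of length levels * features_per_entry
        return y.reshape(-1)

x = jnp.ones(3)
emb = HashEmbedding().init(key=42, inputs=x)
print(emb(x).shape)  # (32,) which is (levels * features_per_entry,)
```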
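The defaults above determine the size of the encoding. The output length is levels * features_per_entry = 16 * 2 = 32, matching the printed shape, and the tables hold levels * 2**hashmap_size_log2 * features_per_entry trainable parameters. Assuming sdf_jax follows the paper's per-level growth factor b = exp((ln nmax - ln nmin) / (levels - 1)), the defaults give b = 32^(1/15) = 2^(1/3), so the grid resolution roughly doubles every three levels between nmin = 16 and nmax = 512. This snippet only does that arithmetic in plain Python:

```python
import math

levels, hashmap_size_log2, features_per_entry = 16, 14, 2
nmin, nmax = 16, 512

output_dim = levels * features_per_entry
num_params = levels * (1 << hashmap_size_log2) * features_per_entry

# Per-level growth factor from the paper (assumed, not confirmed, to match sdf_jax)
growth = math.exp((math.log(nmax) - math.log(nmin)) / (levels - 1))
resolutions = [math.floor(nmin * growth**level) for level in range(levels)]

print(output_dim)  # 32
print(num_params)  # 524288
print(growth)      # about 1.26, i.e. 2 ** (1 / 3)
print(resolutions)
```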

Installation

To install this repo and its dev dependencies reproducibly, either:

  1. Use Poetry. Make sure you have a local installation of Python >= 3.8 (e.g. by running pyenv local 3.X.X), then run:

    poetry install
  2. Alternatively, use the requirements.txt I've also included, which was generated from the pyproject.toml and poetry.lock files:

    pip install -r requirements.txt
