srush/anynp

anynp

Proof-of-concept of global switching between numpy/jax/pytorch in a library.

This is a wrapper around the Array API standard and the array_api_compat library. It adds two things: a type stub for the np namespace, so that mypy doesn't complain, and a context manager for switching the active backend.

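```python
from switcher import switch, XP
from numpy.typing import ArrayLike


def my_fun(x) -> ArrayLike:
    # switch.xp resolves to whichever array namespace is currently active.
    array = switch.xp.asarray(x)
    print(type(array))
    return array


# Outside any context, the default backend is used.
x = my_fun([10])

# Inside this block, my_fun builds a JAX array.
with switch.set_context(XP.Jax):
    x = my_fun([10])
    print(x.at[0].set(0))  # .at[...].set(...) is JAX's functional update

# Inside this block, my_fun builds a PyTorch tensor.
with switch.set_context(XP.Torch):
    my_fun([10])

# Back outside the contexts, the earlier backend applies again.
my_fun([20])
```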
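
For readers curious how a switcher like this might work internally, here is a minimal sketch. It is not the code in this repo: `ArrayNamespace`, `Backend`, and `Switcher` are hypothetical names, and the backends are loaded with plain `importlib` rather than through array_api_compat. It only illustrates the two ideas above, a typed stub for the namespace and a context manager that swaps and then restores the active backend.

```python
import importlib
from contextlib import contextmanager
from enum import Enum
from typing import Any, Iterator, Protocol


class ArrayNamespace(Protocol):
    """Hypothetical stub: the slice of the array namespace a type checker sees."""

    def asarray(self, obj: Any, /) -> Any: ...


class Backend(Enum):
    # Hypothetical stand-in for XP; each value is an importable module name.
    NUMPY = "numpy"
    JAX = "jax.numpy"
    TORCH = "torch"


class Switcher:
    def __init__(self, default: Backend = Backend.NUMPY) -> None:
        self.active = default

    @property
    def xp(self) -> ArrayNamespace:
        # Imported lazily, so only the backend actually in use must be installed.
        module: Any = importlib.import_module(self.active.value)
        return module

    @contextmanager
    def set_context(self, backend: Backend) -> Iterator[None]:
        previous = self.active
        self.active = backend
        try:
            yield
        finally:
            # Restore the outer backend even if the body raised.
            self.active = previous


switch = Switcher()
```

Because `set_context` remembers the previous backend rather than resetting to a fixed default, contexts nest correctly. Note that this sketch keeps the state on a single shared object, so it is neither thread-local nor task-local; `contextvars.ContextVar` would be one way to change that.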
