[question] Seeking information on low-level TPU interaction and libtpu.so API #7803
I'm looking to build an automatic differentiation library for TPUs without using high-level front-ends like TensorFlow/JAX/PyTorch-XLA, but I'm finding that information about lower-level TPU usage is practically nonexistent.

Specifically, I'm interested in low-level TPU interaction and the libtpu.so API.

Are there any insights or suggestions on how to approach this, particularly regarding TPU support? Any ideas or help would be greatly appreciated. I understand that some of this information might be proprietary, but any guidance on what is possible or available would be very helpful.

Comments
@will-cromar should be able to share some information.
All three frameworks interact with TPUs through PJRT. Almost all of our interactions with PJRT are in this folder, and it's largely independent from PyTorch itself: https://github.com/pytorch/xla/tree/master/torch_xla/csrc/runtime. Specifically, to create a PJRT TPU client, you would need to go through `torch_xla/csrc/runtime/pjrt_registry.cc`, lines 118 to 126 (at commit dd3b00c).
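For orientation, here is a minimal sketch of what that registration path boils down to, assuming the OpenXLA PJRT wrappers (`pjrt::LoadPjrtPlugin`, `pjrt::InitializePjrtPlugin`, `xla::GetCApiClient`). The library path and device-type strings are assumptions, and header locations and signatures drift between OpenXLA revisions, so check them against your pin:

```cpp
// Sketch only: load libtpu.so as a PJRT plugin and wrap it in a C++ client.
#include <memory>

#include "xla/pjrt/pjrt_api.h"           // pjrt::LoadPjrtPlugin, pjrt::InitializePjrtPlugin
#include "xla/pjrt/pjrt_c_api_client.h"  // xla::GetCApiClient
#include "xla/pjrt/pjrt_client.h"        // xla::PjRtClient

absl::StatusOr<std::unique_ptr<xla::PjRtClient>> CreateTpuClient() {
  // Register the plugin's PJRT_Api* under the "tpu" device type.
  // "/lib/libtpu.so" is a placeholder path; torch_xla resolves the real
  // path from the installed libtpu package.
  auto api = pjrt::LoadPjrtPlugin("tpu", "/lib/libtpu.so");
  if (!api.ok()) return api.status();

  // Run the plugin's one-time initialization hook.
  auto init = pjrt::InitializePjrtPlugin("tpu");
  if (!init.ok()) return init;

  // Wrap the loaded PJRT C API in the PjRtClient C++ interface.
  return xla::GetCApiClient("tpu");
}
```

Under the hood, loading the plugin is just a `dlopen` of libtpu.so followed by resolving its exported `GetPjrtApi()` symbol, which is the whole public surface of the library.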
Once you have a client instantiated, your interactions are going to look a lot like this example from JAX: https://github.com/google/jax/blob/main/examples/jax_cpp/main.cc. We use the PJRT C++ API directly, but it's worth noting that (other than the example above) JAX actually mainly interacts with PJRT through Python bindings. I'm not nearly as familiar with those, so you'll have better luck asking in their repository if you want to use the same bindings.
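To make the shape of that flow concrete, here is a hedged sketch of the same compile-and-run sequence the JAX example follows, written against a PJRT C++ client like the one from the sketch above. The `add_one` computation is invented for illustration, and `BufferFromHostBuffer` in particular has changed signatures across OpenXLA versions, so treat this as a sketch rather than a drop-in program:

```cpp
// Sketch only: build, compile, and run f(x) = x + 1 on the first TPU device.
// Uses .value() for brevity, which aborts on error; real code should
// propagate the absl::Status instead.
#include <memory>
#include <vector>

#include "xla/client/xla_builder.h"  // xla::XlaBuilder and op free functions
#include "xla/literal.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/shape_util.h"

void RunAddOne(xla::PjRtClient* client) {
  // 1. Build an HLO computation with the client-side builder: f32[4] -> f32[4].
  xla::XlaBuilder builder("add_one");
  xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {4});
  xla::XlaOp x = xla::Parameter(&builder, 0, shape, "x");
  xla::Add(x, xla::ConstantR0<float>(&builder, 1.0f));  // scalar broadcasts
  xla::XlaComputation computation = builder.Build().value();

  // 2. Compile for the devices behind this client (TPU here).
  std::unique_ptr<xla::PjRtLoadedExecutable> executable =
      client->Compile(computation, xla::CompileOptions()).value();

  // 3. Copy input data from host memory to the first addressable device.
  std::vector<float> host_input = {1.0f, 2.0f, 3.0f, 4.0f};
  std::unique_ptr<xla::PjRtBuffer> input =
      client
          ->BufferFromHostBuffer(
              host_input.data(), xla::F32, /*dims=*/{4},
              /*byte_strides=*/std::nullopt,
              xla::PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall,
              /*on_done_with_host_buffer=*/nullptr,
              client->addressable_devices()[0])
          .value();

  // 4. Execute and copy the result back to a host-side Literal.
  auto results =
      executable->Execute({{input.get()}}, xla::ExecuteOptions()).value();
  std::shared_ptr<xla::Literal> out = results[0][0]->ToLiteralSync().value();
  // out now holds f32[4] {2, 3, 4, 5}.
}
```

The main difference from the JAX example is step 1: main.cc loads a pre-serialized HLO module from disk instead of building one with `XlaBuilder`, but compilation and execution proceed the same way from there.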