
[question] Seeking information on low-level TPU interaction and libtpu.so API #7803

Open
notlober opened this issue Aug 2, 2024 · 2 comments

@notlober

notlober commented Aug 2, 2024

I'm looking to build an automatic differentiation library for TPUs without using high-level front-ends like TensorFlow/JAX/PyTorch-XLA, but I'm finding that information about lower-level TPU usage is practically non-existent.

Specifically, I'm interested in:

  1. How to interact with TPUs at a lower level than what's typically exposed in TensorFlow
  2. Information about the libtpu.so library and its API
  3. Any resources or documentation on implementing custom TPU operations

Are there any insights or suggestions on how to approach this, particularly regarding TPU support? Any ideas or help would be greatly appreciated.

I understand that some of this information might be proprietary, but any guidance on what is possible or available would be very helpful.

@JackCaoG
Collaborator

JackCaoG commented Aug 2, 2024

@will-cromar should be able to share some information.

@will-cromar
Collaborator

All three frameworks interact with libtpu through the PJRT plugin API. Most of the core API for PJRT is documented in comments here: https://github.com/openxla/xla/blob/main/xla/pjrt/pjrt_client.h
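
As a rough orientation, here is a minimal sketch (my own illustration, not code from any of the frameworks) of what that client abstraction exposes, using only a few calls I'm confident are declared in pjrt_client.h:

#include <iostream>

#include "xla/pjrt/pjrt_client.h"

// Hypothetical helper: print basic facts about an already-created PjRtClient.
void DescribeClient(const xla::PjRtClient& client) {
  // Platform the client was created for, e.g. "tpu".
  std::cout << "platform: " << client.platform_name() << "\n";

  // Devices this process can actually launch work on.
  for (xla::PjRtDevice* device : client.addressable_devices()) {
    std::cout << "  device: " << device->DebugString() << "\n";
  }
}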

Almost all of our interactions with PJRT are in this folder, and it's largely independent from PyTorch itself: https://github.com/pytorch/xla/tree/master/torch_xla/csrc/runtime

Specifically, to create a PJRT TPU client, you would need to go through the PjRtCApiClient, similar to this (with device_type = "tpu" and library_path = "/path/to/libtpu.so"):

// Load the PJRT plugin shared library (e.g. libtpu.so) and grab its C API.
const PJRT_Api* c_api = *pjrt::LoadPjrtPlugin(
    absl::AsciiStrToLower(device_type), plugin->library_path());
// Run the plugin's one-time initialization.
XLA_CHECK_OK(pjrt::InitializePjrtPlugin(device_type));
// Create a PjRtClient backed by the plugin's C API.
auto create_options = plugin->client_create_options();
client = xla::GetCApiClient(
             absl::AsciiStrToUpper(device_type),
             {create_options.begin(), create_options.end()}, kv_store)
             .value();
// Register the plugin with the profiler.
profiler::RegisterProfilerForPlugin(c_api);

Once you have a client instantiated, then your interactions are going to look a lot like this example from JAX: https://github.com/google/jax/blob/main/examples/jax_cpp/main.cc
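
To make that concrete, here is a hedged sketch of the same flow (assuming an existing xla::PjRtClient named client and an HLO module given as text; the helper name, the example input, and the exact include paths are mine, and signatures drift a bit between XLA revisions):

#include <memory>
#include <string>

// Include paths are approximate; several of these headers have moved
// around between XLA versions.
#include "xla/client/xla_computation.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/literal_util.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/service/hlo_parser.h"

// Parse HLO text, compile it with PJRT, run it on the first addressable
// device, and copy the first output back to the host.
std::shared_ptr<xla::Literal> CompileAndRun(xla::PjRtClient& client,
                                            const std::string& hlo_text) {
  // Parse the HLO text into an HloModule and wrap it as an XlaComputation.
  std::unique_ptr<xla::HloModule> module =
      xla::ParseAndReturnUnverifiedModule(hlo_text).value();
  xla::XlaComputation computation(module->ToProto());

  // Compile into an executable loaded onto this client's devices.
  std::unique_ptr<xla::PjRtLoadedExecutable> executable =
      client.Compile(computation, xla::CompileOptions{}).value();

  // Upload one host literal as the argument; in a real program this has to
  // match the parameter shape of the HLO module.
  xla::Literal arg = xla::LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f});
  std::unique_ptr<xla::PjRtBuffer> input =
      client.BufferFromHostLiteral(arg, client.addressable_devices()[0])
          .value();

  // Execute and block until the first output is available on the host.
  auto results =
      executable->Execute({{input.get()}}, xla::ExecuteOptions{}).value();
  return results[0][0]->ToLiteralSync().value();
}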

We use the PJRT C++ API directly, but it's worth noting that (other than the example above) JAX actually mainly interacts with PJRT through Python bindings. I'm not nearly as familiar with those, so you'll have better luck asking in their repository if you want to use the same bindings.

The framework code outside of libtpu.so is all open source. I'm happy to help if you have any questions about the PJRT C++ API.
