[NOT FOR MERGE] Rwitten host offload demo #535

Open
wants to merge 2 commits into base: main
Changes from 1 commit
host offload demo (crashing)
Rafi Witten committed Mar 20, 2024
commit df2e1dd75fc4c4a4cb7e1befc4c505d3a5a29eee
68 changes: 68 additions & 0 deletions pedagogical_examples/host_offload.py
@@ -0,0 +1,68 @@
#!/usr/bin/python3

"""
Copyright 2023 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

'''
Frequently folks want to offload tensors from device memory (HBM) to host (CPU)
memory. As of March 2024, JAX supports this via sharding annotations: a
NamedSharding carries a memory kind such as 'device' or 'pinned_host', and
jit out_shardings can move arrays between the two memory spaces.
'''

from functools import partial

import jax
from jax.sharding import Mesh
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec
from jax.experimental import mesh_utils

# Enable the host memory-space feature so shardings can carry a memory kind.
jax.config.update('jax_enable_memories', True)

devices = mesh_utils.create_device_mesh((jax.device_count(),))
mesh_axis_name = "axis"
global_mesh = Mesh(devices, (mesh_axis_name,))  # axis_names must be a tuple; note the trailing comma
array_on_device_sharding = NamedSharding(global_mesh, PartitionSpec(mesh_axis_name))

data_dim = 16384  # each tensor is 16384 x 16384 in bfloat16, i.e. 512 MiB
num_tensors = 4

@partial(jax.jit, out_shardings=array_on_device_sharding)
def generate_array():
  return jax.numpy.ones((data_dim, data_dim), dtype=jax.numpy.bfloat16)

# Materialize the tensors on device, then derive host and device variants of
# their shardings by swapping the memory kind.
data = [generate_array() for _ in range(num_tensors)]
shardings = jax.tree.map(lambda x: x.sharding, data)

host_out_shardings = jax.tree.map(lambda x: x.with_memory_kind('pinned_host'), shardings)
device_out_shardings = jax.tree.map(lambda x: x.with_memory_kind('device'), shardings)
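
# An eager alternative (a sketch, not part of the original demo; the variable
# name is illustrative): jax.device_put also accepts a pytree of shardings with
# memory kinds, so the same offload works without wrapping an identity
# function in jit.
eager_host_data = jax.device_put(data, host_out_shardings)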


# Jitted identity functions: the computation is trivial, but out_shardings
# moves the result into the requested memory space.
@partial(jax.jit, out_shardings=host_out_shardings)
def put_to_host(x):
  return x

@partial(jax.jit, out_shardings=device_out_shardings)
def put_to_device(x):
  return x

# Round-trip: offload to pinned host memory, then bring the data back on device.
host_data = put_to_host(data)
device_data = put_to_device(host_data)
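
# Sanity check (a sketch, assuming the transfers above succeed): each array's
# sharding records the memory kind it currently lives in.
for x in host_data:
  assert x.sharding.memory_kind == 'pinned_host'
for x in device_data:
  assert x.sharding.memory_kind == 'device'
print("offloaded to:", host_data[0].sharding.memory_kind)
print("restored to:", device_data[0].sharding.memory_kind)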