-
Notifications
You must be signed in to change notification settings - Fork 244
/
shmap_collective_matmul.py
209 lines (163 loc) · 6.72 KB
/
shmap_collective_matmul.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
#!/usr/bin/python3
"""
Copyright 2023 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""This script is an example collective matmul."""
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
import numpy as np
from functools import partial
import jax
import jax.numpy as jnp
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec as P
from jax.sharding import Mesh
from jax.experimental.shard_map import shard_map
# Mesh axis names: data-parallel, fully-sharded-data-parallel, tensor-parallel.
MESH_DATA_AXIS = "dp"
MESH_FSDP_AXIS = "fsdp"
MESH_TENSOR_AXIS = "tp"

# Build a (1, 1, 4) mesh: every device sits on the tensor axis. Devices are
# taken in the order 0, 1, 3, 2 -- NOTE(review): presumably a ring order that
# matches the physical interconnect; confirm. Needs at least 4 devices.
d = jax.devices()
outd = [[[d[idx] for idx in (0, 1, 3, 2)]]]
global_mesh = Mesh(outd, (MESH_DATA_AXIS, MESH_FSDP_AXIS, MESH_TENSOR_AXIS))
print(global_mesh.shape)

# Problem sizes used by gen_data_fn.
batch_size = 2
seq_len = 8192
n_heads = 128
head_dim = 128
emb_dim = 16384
import random
import string
import datetime
def simple_timeit(f, *args, tries=10, trace_base_dir=None, task=None):
  """Time `f(*args)` over several runs; print and return the mean in milliseconds.

  One untimed warm-up call is made first, so one-time costs such as jit
  compilation are not included in the measurement. If `trace_base_dir` is
  given, the timed runs are captured with the JAX profiler into
  `<trace_base_dir>/t_<task>_<10 random chars>`.

  Args:
    f: callable to time; its result is passed to `jax.block_until_ready`.
    *args: positional arguments forwarded to `f`.
    tries: number of timed runs; must be >= 1.
    trace_base_dir: optional directory under which a profiler trace is written.
    task: label used in the printed message and the trace name; required.

  Returns:
    Mean duration of the timed runs, in milliseconds.

  Raises:
    ValueError: if `task` is None or `tries` is less than 1.
  """
  import time  # perf_counter is monotonic, unlike datetime.now()

  if task is None:
    raise ValueError("task must be provided")
  if tries < 1:
    raise ValueError(f"tries must be >= 1, got {tries}")
  trace_name = f"t_{task}_" + "".join(random.choice(string.ascii_uppercase + string.digits) for _ in range(10))
  trace_dir = f"{trace_base_dir}/{trace_name}" if trace_base_dir else None

  outcomes_ms = []
  jax.block_until_ready(f(*args))  # warm it up!
  if trace_dir:
    jax.profiler.start_trace(trace_dir)
  try:
    for _ in range(tries):
      start = time.perf_counter()
      jax.block_until_ready(f(*args))
      outcomes_ms.append(1000 * (time.perf_counter() - start))
  finally:
    # Always close the profiler session, even if a timed run raises.
    if trace_dir:
      jax.profiler.stop_trace()
  average_time_ms = sum(outcomes_ms) / len(outcomes_ms)
  print(f"{task}: average time milliseconds: {average_time_ms:.2f}")
  return average_time_ms
# gen data
def gen_data_fn():
  """Generate random bf16 benchmark inputs.

  Shapes come from the module-level sizes: activations are
  (batch_size, seq_len, emb_dim) and weights are (emb_dim, n_heads, head_dim).
  The seed is drawn with NumPy, so it is fixed when the function is traced.

  Returns:
    (activations, weights) tuple.
  """
  key = jax.random.PRNGKey(np.random.randint(0, 256))
  # JAX keys must not be reused across draws; split so the two tensors come
  # from independent random streams.
  activations_key, weights_key = jax.random.split(key)
  activations = jax.random.normal(activations_key, shape=(batch_size, seq_len, emb_dim), dtype=jnp.bfloat16)  # pylint: disable=redefined-outer-name
  weights = jax.random.normal(weights_key, shape=(emb_dim, n_heads, head_dim), dtype=jnp.bfloat16)  # pylint: disable=redefined-outer-name
  return activations, weights
# pjit-wrapped data generator: both outputs are created already sharded with
# dim 0 over the fsdp axis and dim 1 over the tensor axis.
data_fn = pjit(
    gen_data_fn,
    out_shardings=(P(MESH_FSDP_AXIS, MESH_TENSOR_AXIS, None),) * 2,
)
def matmul(activations, weights):  # pylint: disable=redefined-outer-name
  """Reference projection: contract the last activation dim with the first weight dim.

  Indexing follows the einsum spec below: activations "bsE", weights "Ehd",
  result "bshd".
  """
  contraction = "bsE,Ehd->bshd"
  result = jnp.einsum(contraction, activations, weights)
  return result
# Baseline compiled matmul; output "bshd" keeps batch over fsdp and heads over tp.
jit_matmul = pjit(
    matmul,
    out_shardings=P(MESH_FSDP_AXIS, None, MESH_TENSOR_AXIS, None),
)
@partial(
    shard_map,
    mesh=global_mesh,
    in_specs=(
        P(MESH_FSDP_AXIS, MESH_TENSOR_AXIS, None),
        P(MESH_FSDP_AXIS, MESH_TENSOR_AXIS, None),
    ),
    out_specs=P(MESH_FSDP_AXIS, None, MESH_TENSOR_AXIS, None),
    check_rep=False,
)
def collective_matmul(activations, weights):  # pylint: disable=redefined-outer-name
  """Collective matrix multiply: gather activations over the tensor axis while computing.

  Runs per device under shard_map. Per the in_specs, each device holds one
  sequence chunk of `activations` (dim 1 sharded over tp) and one head slice
  of `weights` (dim 1 sharded over tp). Rather than all-gathering activations
  up front, the local chunk is split into two halves that travel around the
  tensor-axis ring in opposite directions via ppermute. Each step multiplies
  the halves currently held; the ppermute for the next hop does not depend on
  that step's einsums, so XLA is free to overlap communication with compute.

  Args:
    activations: local activation shard, contracted as "bsE".
    weights: local weight shard, contracted as "Ehd".

  Returns:
    Local output block of shape (activations.shape[0],
    activations.shape[1] * axis_size, weights.shape[-2], weights.shape[-1]).
  """
  print(f"sh_map {activations.shape=} {weights.shape=}")
  axis_size = jax.lax.psum(1, axis_name=MESH_TENSOR_AXIS)
  axis_index = jax.lax.axis_index(axis_name=MESH_TENSOR_AXIS)
  # Ring permutations on the tensor axis, built once: send to next / previous device.
  forward_perm = [(j, (j + 1) % axis_size) for j in range(axis_size)]
  backward_perm = [(j, (j - 1) % axis_size) for j in range(axis_size)]

  # The current sequence chunk
  chunk_size = activations.shape[1]
  mid_chunk = chunk_size // 2

  # create accum buffer spanning every device's sequence chunk
  accum = jnp.zeros(
      (
          activations.shape[0],
          activations.shape[1] * axis_size,
          weights.shape[-2],
          weights.shape[-1],
      ),
      dtype=activations.dtype,
  )

  # compute first chunk: the local one, needing no communication
  update = jnp.einsum("bsE,Ehd->bshd", activations, weights)
  accum = jax.lax.dynamic_update_slice(accum, update, (0, axis_index * chunk_size, 0, 0))

  def accumulate_step(step, accum, activation_forward, activation_backward):
    """Multiply the halves held after `step + 1` hops and write them into `accum`."""
    update_forward = jnp.einsum("bsE,Ehd->bshd", activation_forward, weights)
    update_backward = jnp.einsum("bsE,Ehd->bshd", activation_backward, weights)
    # The forward half moves +1 device per hop, so it originated on
    # axis_index - (step + 1); it is the first half of that device's chunk.
    forward_update_index = ((axis_index - step - 1) % axis_size) * chunk_size
    # The backward half originated on axis_index + (step + 1); it is the second half.
    backward_update_index = ((axis_index + step + 1) % axis_size) * chunk_size + mid_chunk
    accum = jax.lax.dynamic_update_slice(accum, update_forward, (0, forward_update_index, 0, 0))
    accum = jax.lax.dynamic_update_slice(accum, update_backward, (0, backward_update_index, 0, 0))
    return accum

  def scanned_call(i, carrys):
    accum, activation_forward, activation_backward = carrys
    accum = accumulate_step(i, accum, activation_forward, activation_backward)
    # Ship the halves one more hop for the next step (independent of the einsums above).
    activation_forward = jax.lax.ppermute(activation_forward, axis_name=MESH_TENSOR_AXIS, perm=forward_perm)
    activation_backward = jax.lax.ppermute(activation_backward, axis_name=MESH_TENSOR_AXIS, perm=backward_perm)
    return (accum, activation_forward, activation_backward)

  print(f"{accum.shape=}")
  num_steps = axis_size - 1
  if num_steps > 0:
    # Slice (rather than jnp.split) so an odd chunk_size also works; the
    # backward half then carries the extra row.
    activation_forward = activations[:, :mid_chunk]
    activation_backward = activations[:, mid_chunk:]
    activation_forward = jax.lax.ppermute(activation_forward, axis_name=MESH_TENSOR_AXIS, perm=forward_perm)
    activation_backward = jax.lax.ppermute(activation_backward, axis_name=MESH_TENSOR_AXIS, perm=backward_perm)
    accum, activation_forward, activation_backward = jax.lax.fori_loop(
        0, num_steps - 1, scanned_call, (accum, activation_forward, activation_backward)
    )
    # The last step runs outside the loop: every chunk has arrived by then, so
    # the loop body's trailing ppermute (whose result would be unused) is skipped.
    accum = accumulate_step(num_steps - 1, accum, activation_forward, activation_backward)
  return accum
with global_mesh:
  activations, weights = data_fn()
  # Make sure data generation has finished before anything else runs.
  jax.block_until_ready((activations, weights))

  @jax.jit
  def run_naive(_activations, _weights):
    """Baseline: the pjit-wrapped einsum `jit_matmul`, inside a jax.named_scope."""
    with jax.named_scope("naive_matmul"):
      return jit_matmul(_activations, _weights)

  @jax.jit
  def run_collective(_activations, _weights):
    """Candidate: the shard_map `collective_matmul`, inside a jax.named_scope."""
    with jax.named_scope("collective_matmul"):
      return jax.jit(collective_matmul)(_activations, _weights)

  # Run each variant once, then report global shapes next to the first addressable shard's shape.
  naive_outputs = run_naive(activations, weights)
  collective_outputs = run_collective(activations, weights)
  print(f"input {activations.shape=} {activations.addressable_shards[0].data.shape=}")
  print(f"input {weights.shape=} {weights.addressable_shards[0].data.shape=}")
  print(f"naive_outputs {naive_outputs.shape=} {naive_outputs.addressable_shards[0].data.shape=}")
  print(f"collective_outputs {collective_outputs.shape=} {collective_outputs.addressable_shards[0].data.shape=}")
  assert jnp.allclose(naive_outputs, collective_outputs), "Two algorithms should match but don't"

  # Benchmark both; simple_timeit performs its own warm-up call.
  simple_timeit(run_naive, activations, weights, task="naive")
  simple_timeit(run_collective, activations, weights, task="collective")