diff --git a/MaxText/common_types.py b/MaxText/common_types.py
index 2961104f3..e5e01ada7 100644
--- a/MaxText/common_types.py
+++ b/MaxText/common_types.py
@@ -31,11 +31,16 @@
 ScanIn = partitioning.ScanIn
 
 AxisNames = tuple[str, ...]
+AxisIdxes = tuple[int, ...]
 
 BATCH = "activation_batch"
 LENGTH = "activation_length"
 HEAD = "activation_heads"
 D_KV = "activation_kv"
+CACHE_BATCH = "cache_batch"
+CACHE_SEQUENCE = "cache_sequence"
+CACHE_HEADS = "cache_heads"
+CACHE_KV = "cache_kv"
 
 MODEL_MODE_AUTOREGRESSIVE = "autoregressive"
 MODEL_MODE_PREFILL = "prefill"
diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml
index dda8f688e..b53a444c6 100644
--- a/MaxText/configs/base.yml
+++ b/MaxText/configs/base.yml
@@ -284,6 +284,14 @@
 inference_microbenchmark_stages: "prefill,generate"
 inference_microbenchmark_loop_iters: 10
 inference_microbenchmark_log_file_path: ""
 
+# KV Cache layout control
+# Logical layout: 0,1,2,3 ; CACHE_BATCH, CACHE_SEQUENCE, CACHE_HEADS, CACHE_KV
+# Default layout: 1,2,0,3 ; CACHE_SEQUENCE, CACHE_HEADS, CACHE_BATCH, CACHE_KV
+prefill_key_axis_order: "1,2,0,3"
+prefill_value_axis_order: "1,2,0,3"
+ar_key_axis_order: "1,2,0,3"
+ar_value_axis_order: "1,2,0,3"
+
 # Maxengine Metrics
 prometheus_port: 0
diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py
index 61f116375..8e0938444 100644
--- a/MaxText/layers/attentions.py
+++ b/MaxText/layers/attentions.py
@@ -48,10 +48,15 @@
 Quant = quantizations.AqtQuantization
 
 AxisNames = common_types.AxisNames
+AxisIdxes = common_types.AxisIdxes
 BATCH = common_types.BATCH
 LENGTH = common_types.LENGTH
 HEAD = common_types.HEAD
 D_KV = common_types.D_KV
+CACHE_BATCH = common_types.CACHE_BATCH
+CACHE_SEQUENCE = common_types.CACHE_SEQUENCE
+CACHE_HEADS = common_types.CACHE_HEADS
+CACHE_KV = common_types.CACHE_KV
 
 
 DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max)
@@ -104,6 +109,11 @@ class AttentionOp(nn.Module):
   max_prefill_predict_length: int = -1
   float32_logits: bool = False
   flash_axis_names: AxisNames = (BATCH, HEAD, LENGTH, D_KV)
+  kv_cache_logical_layout: AxisNames = (CACHE_BATCH, CACHE_SEQUENCE, CACHE_HEADS, CACHE_KV)
+  prefill_key_axis_order: AxisIdxes = (1, 2, 0, 3)
+  prefill_value_axis_order: AxisIdxes = (1, 2, 0, 3)
+  ar_key_axis_order: AxisIdxes = (1, 2, 0, 3)
+  ar_value_axis_order: AxisIdxes = (1, 2, 0, 3)
   dropout_rate: float = 0.0
   dtype: DType = jnp.float32
   quant: Optional[Quant] = None
@@ -325,119 +335,166 @@ def qk_product(self, query: Array, key: Array) -> Array:
     """Query-Key product.
 
     Args:
-      query: Query projection, in shape of [b, t, n, d], where b: batch size, t:
-        query length, n: number of heads, d: project dimension.
-      key: Key projection in shape of [b, s, n_kv, d] for where s: key length, n_kv is
-        kv heads (sometimes k). The number of group for query is n // n_kv (sometimes g).
+      query: Query projection, in shape of [b, t, n, d]
+      key: Key projection in shape of [b, s, n_kv, d]
 
     Returns:
-      results in shape [b, n_kv, n // n_kv, t, s].
+      results in shape [b, n_kv, n // n_kv, t, s].
+
+    Annotations:
+      b: batch size
+      t: query length
+      s: key / value length
+      d: head / kv dimension
+      n: number of query heads
+      n_kv: number of kv heads, sometimes annotated as k
+      n // n_kv: number of query groups, sometimes annotated as g
     """
     b, t, n, d = query.shape
     n_kv = key.shape[-2]
     assert n_kv == self.num_kv_heads
     query = jnp.reshape(query, (b, t, n_kv, n // n_kv, d))
     result = jnp.einsum("btkgd,bskd->bkgts", query, key)
-    return result  # (4, 8, 1, 1, 6)
+    return result
 
   def wv_product(self, attn_weights: Array, value: Array) -> Array:
     """weighted value product.
 
     Args:
-      attn_weights: Computed results of qk_einsum, in shape [batch_size, num_kv_heads, group_size, q_len, k_len].
-      value: Value projection, in shape of [batch_size, v_len, num_kv_heads, kv_dim].
+      attn_weights: Computed results of qk_einsum, in shape [b, n_kv, n // n_kv, t, s]
+      value: Value projection, in shape of [b, s, n_kv, d]
 
     Returns:
-      result in shape [batch_size, q_len, num_kv_heads * group_size, kv_dim]
+      result in shape [b, t, n, d]
+
+    Annotations:
+      b: batch size
+      t: query length
+      s: key / value length
+      d: head / kv dimension
+      n: number of query heads
+      n_kv: number of kv heads, sometimes annotated as k
+      n // n_kv: number of query groups, sometimes annotated as g
     """
     out = jnp.einsum("bkgts,bskd->btkgd", attn_weights, value)
     b, t, n_kv, g, d = out.shape
     result = jnp.reshape(out, (b, t, n_kv * g, d))
     return result
 
-  def revert_kvlen_axis(self, kv):
-    """Revert key/value length axis.
+  def revert_kv_cache(self, kv, cached_axis_order):
+    """Revert key/value cache to logical shape.
 
     Args:
-      kv: in shape [b, ..., n, d, s].
+      kv: reshaped kv as defined in cached_axis_order
 
     Returns:
-      reshaped kv as [b, ..., s, n, d]
+      kv reverted to logical shape as [b, s, n_kv, d]
+
+    Annotations:
+      b: batch size
+      s: key / value length
+      n_kv: number of kv heads, sometimes annotated as k
+      d: head / kv dimension
+
     """
-    return jax.numpy.moveaxis(kv, (0, 1, 2, 3), (1, 2, 0, 3))
+    return jax.numpy.moveaxis(kv, (0, 1, 2, 3), cached_axis_order)
 
-  def move_kvlen_axis(self, kv):
-    """Move key/value length axis to the end.
+  def reshape_kv_cache(self, kv, cached_axis_order):
+    """Reshape key/value cache as defined in cached_axis_order.
 
     Args:
-      kv: in shape [b, ..., s, n, d].
+      kv: in logical shape as [b, s, n_kv, d]
 
     Returns:
-      reshaped kv as [b, ..., n, d, s]
+      reshaped kv as defined in cached_axis_order
+
+    Annotations:
+      b: batch size
+      s: key / value length
+      n_kv: number of kv heads, sometimes annotated as k
+      d: head / kv dimension
+
     """
-    return jax.numpy.moveaxis(kv, (0, 1, 2, 3), (2, 0, 1, 3))
+    axis_order_to_index_mapping = {a: i for i, a in enumerate(cached_axis_order)}
+    axis_destination = tuple([i for a, i in sorted(axis_order_to_index_mapping.items())])
+    return jax.numpy.moveaxis(kv, (0, 1, 2, 3), axis_destination)
 
-  def cached_kv_shape(self, kv_shape):
+  def cached_kv_layout(self, kv_layout, cached_axis_order):
+    return tuple([kv_layout[i] for i in cached_axis_order])
+
+  def cached_kv_shape(self, kv_shape, cached_axis_order):
     """Cached KV shape.
 
-    The key and value have dimension [batch, length, num_heads, head_dim], but
-    we cache them as [length, num_heads, batch, head_dim, ] for optimized read/write performance.
+    The key and value have dimension [b, s, n_kv, d], but
+    we cache them as defined in cached_axis_order for optimized read/write performance.
 
     Args:
-      kv_shape: shape of key or value for caching, as [b, ..., s, n, d].
+      kv_shape: shape of key or value for caching, as [b, s, n_kv, d].
 
     Returns:
-      Swapped kv_shape as [b, ..., n, d, s] for cache.
+      Swapped kv_shape as defined in cached_axis_order for cache.
+
+    Annotations:
+      b: batch size
+      s: key / value length
+      n_kv: number of kv heads, sometimes annotated as k
+      d: head / kv dimension
+
     """
-    return (kv_shape[1], kv_shape[2], kv_shape[0], kv_shape[3])
+    return tuple([kv_shape[i] for i in cached_axis_order])
 
   def _get_prefill_cache(self, batch, heads, kv_head_size, quantize_kvcache):
     dtype = jnp.int8 if quantize_kvcache else jnp.bfloat16
-    kv_cache_layout = (
-        "cache_sequence",
-        "cache_heads",
-        "cache_batch",
-        "cache_kv",
-    )
     cache_logical_shape = (batch, self.max_prefill_predict_length, heads, kv_head_size)
+    key_layout = self.cached_kv_layout(self.kv_cache_logical_layout, self.prefill_key_axis_order)
+    value_layout = self.cached_kv_layout(self.kv_cache_logical_layout, self.prefill_value_axis_order)
+
+    key_shape = self.cached_kv_shape(cache_logical_shape, self.prefill_key_axis_order)
+    value_shape = self.cached_kv_shape(cache_logical_shape, self.prefill_value_axis_order)
+
     cached_key = self.variable(
         "cache",
         "cached_prefill_key",
-        nn.with_logical_partitioning(jnp.zeros, kv_cache_layout),
-        self.cached_kv_shape(cache_logical_shape),
+        nn.with_logical_partitioning(jnp.zeros, key_layout),
+        key_shape,
         dtype,
     )
     cached_value = self.variable(
         "cache",
         "cached_prefill_value",
-        nn.with_logical_partitioning(jnp.zeros, kv_cache_layout),
-        self.cached_kv_shape(cache_logical_shape),
+        nn.with_logical_partitioning(jnp.zeros, value_layout),
+        value_shape,
         dtype,
     )
     cached_segment_id = self.variable(
         "cache",
         "cache_prefill_segment_id",
-        nn.with_logical_partitioning(jnp.zeros, ("cache_batch", "cache_sequence")),
+        nn.with_logical_partitioning(jnp.zeros, (CACHE_BATCH, CACHE_SEQUENCE)),
        (cache_logical_shape[0], self.max_prefill_predict_length),
         jnp.int32,
     )
 
     if self.quantize_kvcache:
+      cache_logical_shape_scale = (batch, self.max_prefill_predict_length, heads, 1)
+
+      key_shape_scale = self.cached_kv_shape(cache_logical_shape_scale, self.prefill_key_axis_order)
+      value_shape_scale = self.cached_kv_shape(cache_logical_shape_scale, self.prefill_value_axis_order)
+
       cached_key_scale_var = self.variable(
           "cache",
           "cached_prefill_key_scale",
-          nn.with_logical_partitioning(jnp.zeros, kv_cache_layout),
-          self.cached_kv_shape(cache_logical_shape_scale),
+          nn.with_logical_partitioning(jnp.zeros, key_layout),
+          key_shape_scale,
           jnp.bfloat16,
       )
       cached_value_scale_var = self.variable(
           "cache",
           "cached_prefill_value_scale",
-          nn.with_logical_partitioning(jnp.zeros, kv_cache_layout),
-          self.cached_kv_shape(cache_logical_shape_scale),
+          nn.with_logical_partitioning(jnp.zeros, value_layout),
+          value_shape_scale,
           jnp.bfloat16,
       )
     else:
@@ -451,71 +508,67 @@ def _get_prefill_cache(self, batch, heads, kv_head_size, quantize_kvcache):
   def _get_ar_cache(self, batch, heads, kv_head_size, quantize_kvcache):
     dtype = jnp.int8 if quantize_kvcache else jnp.bfloat16
     cache_length = self.max_target_length - self.max_prefill_predict_length
-    kv_cache_layout = (
-        "cache_sequence",
-        "cache_heads",
-        "cache_batch",
-        "cache_kv",
-    )
+
     cache_logical_shape = (batch, cache_length, heads, kv_head_size)
+
+    key_layout = self.cached_kv_layout(self.kv_cache_logical_layout, self.ar_key_axis_order)
+    value_layout = self.cached_kv_layout(self.kv_cache_logical_layout, self.ar_value_axis_order)
+
+    key_shape = self.cached_kv_shape(cache_logical_shape, self.ar_key_axis_order)
+    value_shape = self.cached_kv_shape(cache_logical_shape, self.ar_value_axis_order)
+    # TODO(b/339703100): investigate the issue why with_logical_partitioning doesn't enforce sharding
     cached_key = self.variable(
         "cache",
         "cached_ar_key",
-        nn.with_logical_partitioning(jnp.zeros, kv_cache_layout),
-        self.cached_kv_shape(cache_logical_shape),
+        nn.with_logical_partitioning(jnp.zeros, key_layout),
+        key_shape,
         dtype,
     )
     cached_key.value = nn.with_logical_constraint(
         cached_key.value,
-        (
-            "cache_sequence",
-            "cache_heads",
-            "cache_batch",
-            "cache_kv",
-        ),
+        key_layout,
     )
 
     cached_value = self.variable(
         "cache",
         "cached_ar_value",
-        nn.with_logical_partitioning(jnp.zeros, kv_cache_layout),
-        self.cached_kv_shape(cache_logical_shape),
+        nn.with_logical_partitioning(jnp.zeros, value_layout),
+        value_shape,
         dtype,
     )
     cached_value.value = nn.with_logical_constraint(
         cached_value.value,
-        (
-            "cache_sequence",
-            "cache_heads",
-            "cache_batch",
-            "cache_kv",
-        ),
+        value_layout,
     )
 
     cached_segment_id = self.variable(
         "cache",
         "cache_ar_segment_id",
-        nn.with_logical_partitioning(jnp.zeros, ("cache_batch", "cache_sequence")),
+        nn.with_logical_partitioning(jnp.zeros, (CACHE_BATCH, CACHE_SEQUENCE)),
         (cache_logical_shape[0], cache_length),
         jnp.int32,
     )
 
     if self.quantize_kvcache:
+      cache_logical_shape_scale = (batch, cache_length, heads, 1)
+
+      key_shape_scale = self.cached_kv_shape(cache_logical_shape_scale, self.ar_key_axis_order)
+      value_shape_scale = self.cached_kv_shape(cache_logical_shape_scale, self.ar_value_axis_order)
+
       cached_key_scale_var = self.variable(
           "cache",
           "cached_ar_key_scale",
-          nn.with_logical_partitioning(jnp.zeros, kv_cache_layout),
-          self.cached_kv_shape(cache_logical_shape_scale),
+          nn.with_logical_partitioning(jnp.zeros, key_layout),
+          key_shape_scale,
          jnp.bfloat16,
       )
       cached_value_scale_var = self.variable(
          "cache",
          "cached_ar_value_scale",
-          nn.with_logical_partitioning(jnp.zeros, kv_cache_layout),
-          self.cached_kv_shape(cache_logical_shape_scale),
+          nn.with_logical_partitioning(jnp.zeros, value_layout),
+          value_shape_scale,
          jnp.bfloat16,
       )
     else:
@@ -551,14 +604,22 @@ def kv_cache_prefill(
     cached_prefill_key_var, cached_prefill_value_var, cached_prefill_segment_id = self._get_prefill_cache(
         batch, heads, kv_head_size, self.quantize_kvcache
     )
-    self._get_ar_cache(batch, heads, kv_head_size, self.quantize_kvcache)  # initialize it now
+    cached_ar_key_var, cached_ar_value_var, _, _ = self._get_ar_cache(batch, heads, kv_head_size, self.quantize_kvcache)  # initialize it now
 
-    key_shaped_for_cache = self.move_kvlen_axis(key)
-    value_shaped_for_cache = self.move_kvlen_axis(value)
+    assert cached_prefill_key_var[0].value.shape == self.cached_kv_shape((batch, self.max_prefill_predict_length, heads, kv_head_size), self.prefill_key_axis_order)
+    assert cached_prefill_value_var[0].value.shape == self.cached_kv_shape((batch, self.max_prefill_predict_length, heads, kv_head_size), self.prefill_value_axis_order)
+    assert cached_ar_key_var[0].value.shape == self.cached_kv_shape((batch, self.max_target_length - self.max_prefill_predict_length, heads, kv_head_size), self.ar_key_axis_order)
+    assert cached_ar_value_var[0].value.shape == self.cached_kv_shape((batch, self.max_target_length - self.max_prefill_predict_length, heads, kv_head_size), self.ar_value_axis_order)
+
+    prefill_key_layout = self.cached_kv_layout(self.kv_cache_logical_layout, self.prefill_key_axis_order)
+    prefill_value_layout = self.cached_kv_layout(self.kv_cache_logical_layout, self.prefill_value_axis_order)
+
+    key_shaped_for_cache = self.reshape_kv_cache(key, self.prefill_key_axis_order)
+    value_shaped_for_cache = self.reshape_kv_cache(value, self.prefill_value_axis_order)
 
     if self.quantize_kvcache:
-      key_shaped_for_cache, key_scale = quantizations.quantize_kv(key_shaped_for_cache)
-      value_shaped_for_cache, value_scale = quantizations.quantize_kv(value_shaped_for_cache)
+      key_shaped_for_cache, key_scale = quantizations.quantize_kv(key_shaped_for_cache, prefill_key_layout.index(CACHE_KV))
+      value_shaped_for_cache, value_scale = quantizations.quantize_kv(value_shaped_for_cache, prefill_value_layout.index(CACHE_KV))
       cached_prefill_key_var[1].value = key_scale
       cached_prefill_value_var[1].value = value_scale
@@ -595,48 +656,49 @@ def update_ar_key_value(
     cached_value_var, cached_value_scale_var = cached_value_vars
 
     # In order to update the key, value caches with the current key and
-    # value, we move the length axis to the back
-    one_token_key = self.move_kvlen_axis(one_token_key)
-    one_token_value = self.move_kvlen_axis(one_token_value)
+    # value, we reshape the one_token_key and one_token_value
+    one_token_key_shaped_for_cache = self.reshape_kv_cache(one_token_key, self.ar_key_axis_order)
+    one_token_value_shaped_for_cache = self.reshape_kv_cache(one_token_value, self.ar_value_axis_order)
+
+    ar_key_layout = self.cached_kv_layout(self.kv_cache_logical_layout, self.ar_key_axis_order)
+    ar_value_layout = self.cached_kv_layout(self.kv_cache_logical_layout, self.ar_value_axis_order)
 
     if self.quantize_kvcache:
-      one_token_key, one_token_key_scale = quantizations.quantize_kv(one_token_key)
-      one_token_value, one_token_value_scale = quantizations.quantize_kv(one_token_value)
+      one_token_key_shaped_for_cache, one_token_key_scale = quantizations.quantize_kv(one_token_key_shaped_for_cache, ar_key_layout.index(CACHE_KV))
+      one_token_value_shaped_for_cache, one_token_value_scale = quantizations.quantize_kv(one_token_value_shaped_for_cache, ar_value_layout.index(CACHE_KV))
 
     one_hot_indices = one_hot_indices.astype(int)
 
     ar_key = cached_key_var.value
-    ar_key = jax.lax.dynamic_update_index_in_dim(ar_key, one_token_key, jnp.squeeze(one_hot_indices), 0)
+    ar_key = jax.lax.dynamic_update_index_in_dim(ar_key, one_token_key_shaped_for_cache, jnp.squeeze(one_hot_indices), ar_key_layout.index(CACHE_SEQUENCE))
     ar_key = nn.with_logical_constraint(
         ar_key,
-        (
-            "cache_sequence",
-            "cache_heads",
-            "cache_batch",
-            "cache_kv",
-        ),
+        ar_key_layout
     )
     cached_key_var.value = ar_key
 
     ar_value = cached_value_var.value
-    ar_value = jax.lax.dynamic_update_index_in_dim(ar_value, one_token_value, jnp.squeeze(one_hot_indices), 0)
+    ar_value = jax.lax.dynamic_update_index_in_dim(ar_value, one_token_value_shaped_for_cache, jnp.squeeze(one_hot_indices), ar_key_layout.index(CACHE_SEQUENCE))
     ar_value = nn.with_logical_constraint(
         ar_value,
-        (
-            "cache_sequence",
-            "cache_heads",
-            "cache_batch",
-            "cache_kv",
-        ),
+        ar_value_layout,
     )
     cached_value_var.value = ar_value
 
     if self.quantize_kvcache:
       ar_key_scale = jax.lax.dynamic_update_index_in_dim(
-          cached_key_scale_var.value, one_token_key_scale, jnp.squeeze(one_hot_indices), 0
+          cached_key_scale_var.value, one_token_key_scale, jnp.squeeze(one_hot_indices), ar_key_layout.index(CACHE_SEQUENCE)
+      )
+      ar_key_scale = nn.with_logical_constraint(
+          ar_key_scale,
+          ar_key_layout
      )
      ar_value_scale = jax.lax.dynamic_update_index_in_dim(
-          cached_value_scale_var.value, one_token_value_scale, jnp.squeeze(one_hot_indices), 0
+          cached_value_scale_var.value, one_token_value_scale, jnp.squeeze(one_hot_indices), ar_key_layout.index(CACHE_SEQUENCE)
+      )
+      ar_value_scale = nn.with_logical_constraint(
+          ar_value_scale,
+          ar_value_layout
      )
      cached_key_scale_var.value = ar_key_scale
      cached_value_scale_var.value = ar_value_scale
@@ -644,16 +706,16 @@ def update_ar_key_value(
       ar_key = quantizations.unquantize_kv(cached_key_var.value, cached_key_scale_var.value, one_token_key.dtype)
       ar_value = quantizations.unquantize_kv(cached_value_var.value, cached_value_scale_var.value, one_token_value.dtype)
 
-    # Move the keys and values back to their original shapes.
-    return self.revert_kvlen_axis(ar_key), self.revert_kvlen_axis(ar_value)
+    # Revert the keys and values back to original logical shapes.
+    return self.revert_kv_cache(ar_key, self.ar_key_axis_order), self.revert_kv_cache(ar_value, self.ar_value_axis_order)
 
-  def prefill_cache_var_model_var(self, cache_var, target_dtype):
+  def prefill_cache_var_model_var(self, cache_var, target_dtype, cache_axis_order):
     if not self.quantize_kvcache:
-      return self.revert_kvlen_axis(cache_var[0].value)
+      return self.revert_kv_cache(cache_var[0].value, cache_axis_order)
     else:
       raw_cache, quant_scale = cache_var
       raw_cache_unquantized = quantizations.unquantize_kv(raw_cache.value, quant_scale.value, target_dtype)
-      return self.revert_kvlen_axis(raw_cache_unquantized)
+      return self.revert_kv_cache(raw_cache_unquantized, cache_axis_order)
 
   def kv_cache_autoregressive(
       self,
@@ -684,6 +746,9 @@
         batch, heads, kv_head_size, self.quantize_kvcache
     )
 
+    assert cached_ar_key_var[0].value.shape == self.cached_kv_shape((batch, self.max_target_length - self.max_prefill_predict_length, heads, kv_head_size), self.ar_key_axis_order)
+    assert cached_ar_value_var[0].value.shape == self.cached_kv_shape((batch, self.max_target_length - self.max_prefill_predict_length, heads, kv_head_size), self.ar_value_axis_order)
+
     key = nn.with_logical_constraint(key, (BATCH, LENGTH, HEAD, D_KV))
     value = nn.with_logical_constraint(value, (BATCH, LENGTH, HEAD, D_KV))
@@ -694,14 +759,16 @@
     )
     cache_ar_index.value = jnp.mod(cache_ar_index.value + 1, self.max_target_length - self.max_prefill_predict_length)
 
-    # Prep and return both prefill and ar caches
+    # The below retrieves the existing prefill cache variables, not creating new ones
     cached_prefill_key_var, cached_prefill_value_var, cached_prefill_segment_id = self._get_prefill_cache(
-        self.max_target_length, heads, kv_head_size, self.quantize_kvcache
+        batch, heads, kv_head_size, self.quantize_kvcache
     )
+    assert cached_prefill_key_var[0].value.shape == self.cached_kv_shape((batch, self.max_prefill_predict_length, heads, kv_head_size), self.prefill_key_axis_order)
+    assert cached_prefill_value_var[0].value.shape == self.cached_kv_shape((batch, self.max_prefill_predict_length, heads, kv_head_size), self.prefill_value_axis_order)
 
     cached_prefill = (
-        self.prefill_cache_var_model_var(cached_prefill_key_var, key.dtype),
-        self.prefill_cache_var_model_var(cached_prefill_value_var, value.dtype),
+        self.prefill_cache_var_model_var(cached_prefill_key_var, key.dtype, self.prefill_key_axis_order),
+        self.prefill_cache_var_model_var(cached_prefill_value_var, value.dtype, self.prefill_value_axis_order),
         cached_prefill_segment_id.value,
     )
     return cached_prefill, (ar_key, ar_value, cached_ar_segment_id.value)
@@ -709,15 +776,15 @@
   def kv_cache(self, key: Array, value: Array, decoder_segment_ids: Array, model_mode: str) -> tuple:
     """KV cache takes the current state and updates the state accordingly.
 
-    The key and value have dimension [batch, length, num_heads, head_dim],
-    but we cache them as [batch, num_heads, head_dim, length] as a TPU
+    The key and value have dimension [b, s, n_kv, d],
+    but we cache them reshaped as defined by the *_axis_order config as a TPU
     fusion optimization. This also enables the "scatter via one-hot
     broadcast" trick, which means we do a one-hot broadcast instead of a
     scatter/gather operations, resulting in a 3-4x speedup in practice.
 
     Args:
-      key: in shape [b, s, n, d].
-      value: in shape [b, s, n, d].
+      key: in shape [b, s, n_kv, d].
+      value: in shape [b, s, n_kv, d].
       model_mode: model mode controlling model
 
     Returns:
@@ -837,6 +904,11 @@ class Attention(nn.Module):
   value_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV)
   out_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV)
 
+  prefill_key_axis_order: AxisIdxes = (1, 2, 0, 3)
+  prefill_value_axis_order: AxisIdxes = (1, 2, 0, 3)
+  ar_key_axis_order: AxisIdxes = (1, 2, 0, 3)
+  ar_value_axis_order: AxisIdxes = (1, 2, 0, 3)
+
   def query_projection(self, inputs_q: Array) -> Array:
     """Query projection."""
@@ -991,6 +1063,10 @@ def __call__(
         num_kv_heads=self.num_kv_heads,
         dropout_rate=self.dropout_rate,
         dtype=self.dtype,
+        prefill_key_axis_order = self.prefill_key_axis_order,
+        prefill_value_axis_order = self.prefill_value_axis_order,
+        ar_key_axis_order = self.ar_key_axis_order,
+        ar_value_axis_order = self.ar_value_axis_order,
     )
 
     out = attention_op(query, key, value, decoder_segment_ids, model_mode)
diff --git a/MaxText/layers/llama2.py b/MaxText/layers/llama2.py
index 7fbcf4d5a..157bdf78f 100644
--- a/MaxText/layers/llama2.py
+++ b/MaxText/layers/llama2.py
@@ -96,6 +96,10 @@ def __call__(
         name="self_attention",
         quant=self.quant,
         quantize_kvcache=cfg.quantize_kvcache,
+        prefill_key_axis_order=tuple([int(i) for i in cfg.prefill_key_axis_order.split(",")]),
+        prefill_value_axis_order=tuple([int(i) for i in cfg.prefill_value_axis_order.split(",")]),
+        ar_key_axis_order=tuple([int(i) for i in cfg.ar_key_axis_order.split(",")]),
+        ar_value_axis_order=tuple([int(i) for i in cfg.ar_value_axis_order.split(",")]),
     )
 
     attention_lnx = attention_layer(
diff --git a/MaxText/layers/models.py b/MaxText/layers/models.py
index c778e2d38..a9e1d0e63 100644
--- a/MaxText/layers/models.py
+++ b/MaxText/layers/models.py
@@ -92,6 +92,10 @@ def __call__(
         name="self_attention",
         quant=self.quant,
         quantize_kvcache=cfg.quantize_kvcache,
+        prefill_key_axis_order=tuple([int(i) for i in cfg.prefill_key_axis_order.split(",")]),
+        prefill_value_axis_order=tuple([int(i) for i in cfg.prefill_value_axis_order.split(",")]),
+        ar_key_axis_order=tuple([int(i) for i in cfg.ar_key_axis_order.split(",")]),
+        ar_value_axis_order=tuple([int(i) for i in cfg.ar_value_axis_order.split(",")]),
     )
 
     attention_lnx = attention_layer(
diff --git a/MaxText/layers/quantizations.py b/MaxText/layers/quantizations.py
index 9631e662b..3f169e2fb 100644
--- a/MaxText/layers/quantizations.py
+++ b/MaxText/layers/quantizations.py
@@ -176,9 +176,9 @@ def configure_kv_quantization(config: Config):
   return False if not config.quantize_kvcache else True
 
 
-def quantize_kv(kv: Array):
+def quantize_kv(kv: Array, kv_axis: int):
   """Quantize key/values stored in kvcache."""
-  scale = jnp.max(jnp.abs(kv), axis=-1, keepdims=True)
+  scale = jnp.max(jnp.abs(kv), axis=kv_axis, keepdims=True)
   value = jnp.int8(jnp.rint(kv * (MAX_INT8 / scale)))
   return value, scale
 
diff --git a/MaxText/tests/attention_test.py b/MaxText/tests/attention_test.py
index 2a4e2ab97..2692cdad9 100644
--- a/MaxText/tests/attention_test.py
+++ b/MaxText/tests/attention_test.py
@@ -58,6 +58,7 @@ def setUp(self):
     self.num_kv_heads = self.cfg.num_kv_heads
     self.num_query_heads = self.cfg.num_query_heads
     self.max_target_length = self.cfg.max_target_length
+    self.max_prefill_predict_length = self.cfg.max_prefill_predict_length
     self.head_dim = self.cfg.head_dim
     self.embed_dim = self.cfg.base_emb_dim
     self.dtype = self.cfg.dtype
@@ -255,6 +256,119 @@ def tpu_kernel_attention_helper(self, num_kv_heads):
         jax.numpy.allclose(mha_generic_output, mha_generic_flash_output, rtol=1e-01, atol=1e-01, equal_nan=False)
     )
 
+  @pytest.mark.tpu
+  def test_dot_product_1203_1203(self):
+    self.dot_product_attention_helper(
+        prefill_cache_axis_order=(1,2,0,3),
+        ar_cache_axis_order=(1,2,0,3)
+    )
+
+  @pytest.mark.tpu
+  def test_dot_product_1203_2130(self):
+    self.dot_product_attention_helper(
+        prefill_cache_axis_order=(1,2,0,3),
+        ar_cache_axis_order=(2,1,3,0)
+    )
+
+  @pytest.mark.tpu
+  def test_dot_product_2130_1203(self):
+    self.dot_product_attention_helper(
+        prefill_cache_axis_order=(2,1,3,0),
+        ar_cache_axis_order=(1,2,0,3)
+    )
+
+  @pytest.mark.tpu
+  def test_dot_product_2130_2130(self):
+    self.dot_product_attention_helper(
+        prefill_cache_axis_order=(2,1,3,0),
+        ar_cache_axis_order=(2,1,3,0),
+    )
+
+  def dot_product_attention_helper(self, prefill_cache_axis_order, ar_cache_axis_order):
+    self._dot_product_attention(prefill_cache_axis_order, ar_cache_axis_order, quantize_kvcache=False)
+    self._dot_product_attention(prefill_cache_axis_order, ar_cache_axis_order, quantize_kvcache=True)
+
+  def _dot_product_attention(self, prefill_cache_axis_order, ar_cache_axis_order, quantize_kvcache):
+    """Test that dot_product attention is consistent across KV cache axis orderings."""
+    prefill_length = self.max_prefill_predict_length
+    decode_total_length = self.max_target_length
+    lnx, decoder_segment_ids, decoder_positions = self.get_structured_data(self.dtype)
+
+    lnx_prefill = lnx[:, 0:prefill_length, :]
+    decoder_segment_ids_prefill = decoder_segment_ids[:, 0:prefill_length]
+    decoder_positions_prefill = decoder_positions[:, 0:prefill_length]
+
+    attention_w_layout = Attention(
+        config=self.cfg,
+        num_query_heads=self.num_query_heads,
+        num_kv_heads=self.num_kv_heads,
+        head_dim=self.head_dim,
+        max_target_length=self.max_target_length,
+        max_prefill_predict_length=self.max_prefill_predict_length,
+        mesh=self.mesh,
+        attention_kernel="dot_product",
+        dtype=self.dtype,
+        prefill_key_axis_order=prefill_cache_axis_order,
+        prefill_value_axis_order=prefill_cache_axis_order,
+        ar_key_axis_order=ar_cache_axis_order,
+        ar_value_axis_order=ar_cache_axis_order,
+        quantize_kvcache=quantize_kvcache,
+    )
+
+    attention_w_layout_variable = attention_w_layout.init(
+        {"params": self.rng, "aqt": self.rng},
+        jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)),
+        jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)),
+        jnp.ones((self.global_batch_size, self.max_target_length)),
+    )
+
+    attention_w_layout_full = attention_w_layout.apply(
+        attention_w_layout_variable,
+        lnx,
+        lnx,
+        decoder_segment_ids=decoder_segment_ids,
+        inputs_positions=decoder_positions,
+        deterministic=True,
+        model_mode=common_types.MODEL_MODE_TRAIN,
+        rngs={"aqt": self.rng},
+    )
+
+    attention_w_layout_prefill, attention_w_layout_output_cache = attention_w_layout.apply(
+        attention_w_layout_variable,
+        lnx_prefill,
+        lnx_prefill,
+        decoder_segment_ids=decoder_segment_ids_prefill,
+        inputs_positions=decoder_positions_prefill,
+        deterministic=True,
+        model_mode=common_types.MODEL_MODE_PREFILL,
+        rngs={"aqt": self.rng},
+        mutable=["cache"],
+    )
+    self.assertTrue(
+        jax.numpy.allclose(attention_w_layout_full[:, :prefill_length, :], attention_w_layout_prefill, equal_nan=False)
+    )
+
+    for idx in range(prefill_length, decode_total_length):
+
+      lnx_idx = lnx[:, idx : idx + 1, :]
+      decoder_positions_idx = decoder_positions[:, idx : idx + 1]
+
+      attention_w_layout_variable.update(attention_w_layout_output_cache)
+      attention_w_layout_idx, attention_w_layout_output_cache = attention_w_layout.apply(
+          attention_w_layout_variable,
+          lnx_idx,
+          lnx_idx,
+          inputs_positions=decoder_positions_idx,
+          deterministic=True,
+          model_mode=common_types.MODEL_MODE_AUTOREGRESSIVE,
+          rngs={"aqt": self.rng},
+          mutable=["cache"],
+      )
+
+      attention_w_layout_full_this_idx = attention_w_layout_full[:, idx : idx + 1, :]
+      self.assertTrue(attention_w_layout_full_this_idx.shape == attention_w_layout_idx.shape)
+      self.assertTrue(jax.numpy.allclose(attention_w_layout_full_this_idx, attention_w_layout_idx, rtol=1e-02, atol=1e-01, equal_nan=False))
+
 
 if __name__ == "__main__":
   unittest.main()
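
Note on the axis-order remapping: the index arithmetic in `cached_kv_shape`, `reshape_kv_cache`, and `revert_kv_cache` can be summarized with a small standalone sketch. The helper names (`cached_shape`, `to_cache_layout`, `to_logical_layout`) and the toy sizes below are illustrative only, not MaxText APIs; the mapping is the same one the patch uses, where an axis order such as `(1, 2, 0, 3)` sends the logical `[batch, sequence, kv_heads, head_dim]` layout into the cached layout and back.

```python
# Standalone sketch of the KV-cache axis-order remapping (illustrative names).
import jax.numpy as jnp

def cached_shape(logical_shape, axis_order):
  # (b, s, n_kv, d) with order (1, 2, 0, 3) -> (s, n_kv, b, d)
  return tuple(logical_shape[i] for i in axis_order)

def to_cache_layout(kv, axis_order):
  # Logical axis `a` moves to the position it occupies in `axis_order`.
  destination = tuple(axis_order.index(a) for a in range(kv.ndim))
  return jnp.moveaxis(kv, tuple(range(kv.ndim)), destination)

def to_logical_layout(kv_cached, axis_order):
  # Inverse mapping: cached axis `i` holds logical axis `axis_order[i]`.
  return jnp.moveaxis(kv_cached, tuple(range(kv_cached.ndim)), axis_order)

b, s, n_kv, d = 2, 6, 4, 8          # toy sizes
order = (1, 2, 0, 3)                # the default *_axis_order in base.yml
kv = jnp.zeros((b, s, n_kv, d))

cached = to_cache_layout(kv, order)
assert cached.shape == cached_shape((b, s, n_kv, d), order) == (s, n_kv, b, d)
assert to_logical_layout(cached, order).shape == (b, s, n_kv, d)
```

This is also why the comment in base.yml describes `1,2,0,3` as the `CACHE_SEQUENCE, CACHE_HEADS, CACHE_BATCH, CACHE_KV` layout.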
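
Note on the `quantize_kv` change: the new `kv_axis` argument is needed because the head_dim (`CACHE_KV`) axis is no longer guaranteed to be last once the cache is stored in a permuted layout. A minimal sketch, reusing the max-abs int8 scheme shown in quantizations.py; the dequantization step is an assumed inverse for illustration, not code from the patch:

```python
# Sketch: int8 KV quantization with the scale reduced along the cache_kv axis.
import jax.numpy as jnp

MAX_INT8 = 127.0

def quantize_kv(kv, kv_axis):
  # Max-abs scale along whichever cached axis holds head_dim, kept with
  # keepdims=True so it broadcasts against the cached tensor.
  scale = jnp.max(jnp.abs(kv), axis=kv_axis, keepdims=True)
  value = jnp.int8(jnp.rint(kv * (MAX_INT8 / scale)))
  return value, scale

def dequantize_kv(value, scale, dtype=jnp.bfloat16):
  # Assumed inverse of the scheme above, for illustration only.
  return (value.astype(dtype) * scale.astype(dtype)) / MAX_INT8

layout = ("cache_sequence", "cache_heads", "cache_batch", "cache_kv")
kv_cached = jnp.ones((6, 4, 2, 8), dtype=jnp.bfloat16)  # toy cached tensor

q, scale = quantize_kv(kv_cached, layout.index("cache_kv"))
assert scale.shape == (6, 4, 2, 1)   # one scale per (position, head, batch)
roundtrip = dequantize_kv(q, scale)
```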
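
Note on the autoregressive update: `update_ar_key_value` writes one decode step's K/V into the cache along whichever cached axis holds `CACHE_SEQUENCE`. A toy sketch of that indexing follows; the shapes and variable names are made up, only the `dynamic_update_index_in_dim(..., seq_axis)` call pattern comes from the patch.

```python
# Sketch: writing one decode step into a permuted AR cache.
import jax
import jax.numpy as jnp

layout = ("cache_sequence", "cache_heads", "cache_batch", "cache_kv")
seq_axis = layout.index("cache_sequence")  # 0 for the default 1,2,0,3 order

cache_len, n_kv, b, d = 12, 4, 2, 8        # toy sizes
ar_cache = jnp.zeros((cache_len, n_kv, b, d), dtype=jnp.bfloat16)

# One token of K/V, already reshaped into the cache layout: size 1 on seq_axis.
one_token = jnp.ones((1, n_kv, b, d), dtype=jnp.bfloat16)
position = 3                               # current autoregressive index

ar_cache = jax.lax.dynamic_update_index_in_dim(ar_cache, one_token, position, seq_axis)
assert bool((ar_cache[position] == 1).all())
```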