Skip to content

Commit

Permalink
Enable kv cache layout control
Browse files Browse the repository at this point in the history
  • Loading branch information
morgandu committed Jun 3, 2024
1 parent f12ba54 commit 57429da
Show file tree
Hide file tree
Showing 7 changed files with 319 additions and 108 deletions.
5 changes: 5 additions & 0 deletions MaxText/common_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,16 @@
# Module-level alias for the ScanIn name provided by the partitioning module.
ScanIn = partitioning.ScanIn

# Type alias: a variable-length tuple of axis-name strings.
AxisNames = tuple[str, ...]
# Type alias: a variable-length tuple of integer axis indices.
# NOTE(review): presumably carries axis-order permutations for the KV cache layout — confirm against callers.
AxisIdxes = tuple[int, ...]

# String labels for the activation axes (batch, length, heads, kv).
BATCH = "activation_batch"
LENGTH = "activation_length"
HEAD = "activation_heads"
D_KV = "activation_kv"
# String labels for the KV cache axes (batch, sequence, heads, kv).
CACHE_BATCH = "cache_batch"
CACHE_SEQUENCE = "cache_sequence"
CACHE_HEADS = "cache_heads"
CACHE_KV = "cache_kv"

# String identifiers for the model modes.
MODEL_MODE_AUTOREGRESSIVE = "autoregressive"
MODEL_MODE_PREFILL = "prefill"
Expand Down
8 changes: 8 additions & 0 deletions MaxText/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,14 @@ inference_microbenchmark_stages: "prefill,generate"
inference_microbenchmark_loop_iters: 10
inference_microbenchmark_log_file_path: ""

# KV Cache layout control
# Each *_axis_order value is a comma-separated list of the logical axis indices below,
# written in the order in which the axes should be laid out.
# Logical layout: 0,1,2,3 ; CACHE_BATCH, CACHE_SEQUENCE, CACHE_HEADS, CACHE_KV
# Default layout: 1,2,0,3 ; CACHE_SEQUENCE, CACHE_HEADS, CACHE_BATCH, CACHE_KV
# NOTE(review): the "prefill_" and "ar_" prefixes presumably select the prefill and
# autoregressive caches respectively; confirm against the code that reads these keys.
prefill_key_axis_order: "1,2,0,3"
prefill_value_axis_order: "1,2,0,3"
ar_key_axis_order: "1,2,0,3"
ar_value_axis_order: "1,2,0,3"

# Maxengine Metrics
prometheus_port: 0

Expand Down
Loading

0 comments on commit 57429da

Please sign in to comment.