From 2ca0d24774a696a37b6b9d31813db44662b90574 Mon Sep 17 00:00:00 2001
From: Morgan Du
Date: Sat, 18 May 2024 00:51:52 +0000
Subject: [PATCH] layout api

---
 MaxText/inference_microbenchmark.py | 25 ++++++++++++++++++++++---
 1 file changed, 22 insertions(+), 3 deletions(-)

diff --git a/MaxText/inference_microbenchmark.py b/MaxText/inference_microbenchmark.py
index 3e04c6c6e..701731cae 100644
--- a/MaxText/inference_microbenchmark.py
+++ b/MaxText/inference_microbenchmark.py
@@ -16,6 +16,8 @@
 """Inference microbenchmark for prefill and autoregressive steps."""
 
 import datetime
+import re
+
 import jax
 import json
 import sys
@@ -31,6 +33,15 @@
 
 _WARMUP_ITERS = 2
 
+pattern = re.compile(r"\{(.*?):")
+
+# Extract minor_to_major from str(layout) because layout doesn't have a
+# minor_to_major property yet.
+def extract_minor_to_major(l):
+  match = re.search(pattern, str(l))
+  return tuple(int(i) for i in match.groups()[0].split(','))
+
+
 def prefill_benchmark_loop(engine, params, tokens, true_length, iters):
   """Inner loop for benchmarking prefill step."""
   start = datetime.datetime.now()
@@ -121,12 +132,17 @@ def ar_benchmark_loop(config, engine, params, decode_state, iters, profile_name)
   return (end - start).total_seconds(), decode_state
 
 
+def ar_lowering(engine, params, decode_state):
+  lowered_generate = engine.generate.lower(params, decode_state)
+  compiled_generate = lowered_generate.compile()
+  return compiled_generate
+
+
 def ar_benchmark(config, engine, params, decode_state, global_batch_size, cache_size, model_size, iters):
   """Handles warmup, running ar benchmark, and printing results."""
   for _ in range(_WARMUP_ITERS):
     decode_state, _ = engine.generate(params, decode_state)
   jax.block_until_ready(decode_state)
-
   time_in_s, decode_state = ar_benchmark_loop(config, engine, params, decode_state, iters, profile_name="autoregress")
   seconds_per_step = time_in_s / iters
   ar_average_ms = seconds_per_step * 1000
@@ -270,8 +286,11 @@ def main(config):
       )
 
   if "generate" in stages_to_benchmark:
-    benchmark_results["AutoRegressive"], decode_state = ar_benchmark(
-        config, engine, params, decode_state, engine.max_concurrent_decodes, cache_size, model_size, benchmark_loop_iters)
+    compiled_generate = ar_lowering(engine, params, decode_state)
+    breakpoint()
+
+    # benchmark_results["AutoRegressive"], decode_state = ar_benchmark(
+    #     config, engine, params, decode_state, engine.max_concurrent_decodes, cache_size, model_size, benchmark_loop_iters)
 
   results = collate_results(config, benchmark_results, model_size, cache_size, num_model_params)
   write_results(results, filename=config.inference_microbenchmark_log_file_path)
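
Note (not part of the patch above): a minimal, self-contained sketch of how the extract_minor_to_major helper added in the patch behaves, assuming str(layout) follows a "{minor_to_major:...}" textual form; the example string "{1,0:T(8,128)}" is an illustrative assumption, not taken from MaxText or JAX.

    import re

    # Same regex as in the patch: capture everything between "{" and the first ":".
    pattern = re.compile(r"\{(.*?):")

    def extract_minor_to_major(layout):
      """Parse the minor_to_major tuple out of a layout's string form."""
      match = pattern.search(str(layout))
      return tuple(int(i) for i in match.group(1).split(","))

    # Hypothetical layout string for a row-major 2-D array.
    print(extract_minor_to_major("{1,0:T(8,128)}"))  # -> (1, 0)

Note that the regex only matches when the printed layout contains a ':'; if a layout prints without that suffix, the search returns None and the helper raises, which is presumably acceptable while the layout object lacks a minor_to_major property.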