Add forward pass logit check test for Llama2-7b
A9isha committed May 10, 2024
1 parent e327805 commit 7434359
Showing 8 changed files with 256 additions and 0 deletions.
File renamed without changes.
File renamed without changes.
128 changes: 128 additions & 0 deletions MaxText/scratch_code/golden_llama2-7b_export.ipynb
@@ -0,0 +1,128 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "0d13ebbb",
"metadata": {},
"outputs": [],
"source": [
"!pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu\n",
"!pip3 install tokenizers -U\n",
"!pip3 install transformers -U\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6a8a4bb6",
"metadata": {},
"outputs": [],
"source": [
"import torch \n",
"from transformers import AutoTokenizer, AutoModelForCausalLM \n",
"import json"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ff804403",
"metadata": {},
"outputs": [],
"source": [
"# Load the tokenizer and model from Hugging Face \n",
" \n",
"model_id = \"meta-llama/Llama-2-7b-hf\"\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
"model = AutoModelForCausalLM.from_pretrained(\n",
" model_id,\n",
" torch_dtype=torch.float32,\n",
")\n"
]
},
{
"cell_type": "markdown",
"id": "9f218ba6",
"metadata": {},
"source": [
"## looping over multiple prompts and logits"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c567f8d9",
"metadata": {},
"outputs": [],
"source": [
"# Save to disk \n",
"output_path = \"golden_data_llama2-7b.jsonl\" \n",
" \n",
" \n",
"# Your prompt text \n",
"prompt_texts = [\"I love to\", \"Today is a\", \"What is the\" ]\n",
"\n",
"for prompt_text in prompt_texts:\n",
" # Encode the prompt text \n",
" input_ids = tokenizer.encode(prompt_text, return_tensors='pt') \n",
"\n",
" with torch.no_grad(): \n",
" # Greedy decoding \n",
" output = model.generate(input_ids, max_length=input_ids.shape[1] + 10, num_return_sequences=1) \n",
" \n",
" # Decode the generated ids to a list of tokens \n",
" generated_tokens = tokenizer.convert_ids_to_tokens(output[0]) \n",
" print(generated_tokens)\n",
"\n",
" # Get the logits for the prompt + completion \n",
" with torch.no_grad(): \n",
" outputs = model(output) \n",
" logits = outputs.logits \n",
" \n",
" # Convert logits to fp32 \n",
" logits = logits.cpu().numpy().astype('float32') \n",
"\n",
" # Prepare data to be saved \n",
" data_to_save = { \n",
" \"prompt\": prompt_text, \n",
" \"completion\": tokenizer.decode(output[0]), \n",
" \"tokens\": generated_tokens, \n",
" \"logits\": logits.tolist() # Convert numpy array to list for JSON serialization \n",
" } \n",
" \n",
" with open(output_path, 'w') as f: \n",
" json.dump(data_to_save, f) \n",
"\n",
" \n",
"\n",
" print(f\"Data saved to {output_path}\") \n",
"\n",
"\n",
" \n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
3 changes: 3 additions & 0 deletions MaxText/test_assets/golden_data_llama2-7b.jsonl

Large diffs are not rendered by default.

121 changes: 121 additions & 0 deletions MaxText/tests/forward_pass_logit_checker.py
@@ -0,0 +1,121 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# forward_pass_logit_checker.py compares the logits generated by the MaxText implementation of a model for some
# input prompts against the golden logits for those same prompts. The script is generic: it works with different
# models and expects an input file called golden_data_<model_name>.jsonl to be present under MaxText/test_assets,
# e.g., MaxText/test_assets/golden_data_llama2-7b.jsonl.
# The golden file is a simple JSON Lines file in which each line is a dictionary with the following required keys:
#   1. prompt: a string representing the prompt, e.g., "I love to"
#   2. tokens: the token ids obtained by tokenizing the prompt
#   3. logits: the golden logits, i.e., the reference logits produced by the model in question when fed the prompt in #1
# There can be multiple such test cases in the jsonl file, one test case per line.
# This script runs the forward pass with the input tokens and asserts that the logits generated by the MaxText
# implementation of the same model closely match the golden logits.
# A script similar to MaxText/scratch_code/golden_llama2-7b_export.ipynb can be used to create the jsonl file.
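# For illustration only, a single line of the golden jsonl might look roughly like the following
# (the values shown are placeholders, not copied from the actual golden file):
# {"prompt": "I love to", "tokens": [1, 306, 5360, 304], "logits": [[-6.2, 3.1, ...], [-4.8, 0.9, ...], ...]}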

"""Check if the logits generated by a model's MaxText implementation matches golden logits for the same inputs"""
import sys
import os

current_dir = os.path.dirname(os.path.abspath(__file__))
maxtext_parent_dir = os.path.dirname(current_dir)
sys.path.append(maxtext_parent_dir)

import max_logging
max_logging.log(f"Added parent directory = {maxtext_parent_dir}")

import common_types
import jax
import jax.numpy as jnp
import numpy as np
import pyconfig
import jsonlines
import train


def get_data(golden_data, golden_data_index, config):
""" Get the golden data for the test indexed at golden_data_index"""

max_logging.log(f"Comparing forward pass for golden data index = {golden_data_index} ")
max_logging.log(f"config.global_batch_size_to_train_on={config.global_batch_size_to_train_on}")
s = (config.global_batch_size_to_train_on, config.max_target_length)
ids = np.asarray(golden_data[golden_data_index]['tokens'], dtype=np.int32)

logits = np.asarray(golden_data[golden_data_index]['logits'], dtype=np.float32)
max_logging.log(f" prompt=\"{golden_data[golden_data_index]['prompt']}\" raw ids={ids}, logits.shape = {logits.shape}")


decoder_segment_ids = jax.numpy.zeros(s) + common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR
decoder_positions = jnp.stack(
[jnp.arange(config.max_target_length, dtype=jnp.int32) for _ in range(config.global_batch_size_to_train_on)]
)

ids = jnp.stack(
[ids for _ in range(config.global_batch_size_to_train_on)]
)
max_logging.log(f"ids={ids}, decoder_segment_ids = {decoder_segment_ids}, decoder_positions= {decoder_positions}")

return ids, decoder_segment_ids, decoder_positions, logits

def main(config):
"""Test the Whole Model of model_name"""

# Initialize the model (e.g., Llama2-7b) and restore its pretrained weights
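# Only the model, its restored parameter state, and the init RNG are used in this check;
# the remaining outputs of setup_train_loop are deliberately discarded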
(
init_rng,
_,
_,
_,
model,
_,
_,
_,
_,
state,
) = train.setup_train_loop(config)

input_golden_data_path = "MaxText/test_assets/golden_data_"+config.model_name+".jsonl"
with jsonlines.open(input_golden_data_path, 'r') as f:
golden_data = list(f)


for golden_data_index in range(len(golden_data)):
ids, decoder_segment_ids, decoder_positions, golden_logits = get_data(golden_data, golden_data_index, config)

full_train_logits = model.apply(
state.params,
ids,
decoder_positions,
decoder_segment_ids,
enable_dropout=False,
rngs={"aqt": init_rng},
)

max_logging.log(f"{golden_logits[0] =}, {full_train_logits[0, 0, :]=}")
assert jax.numpy.allclose(
full_train_logits[0, :, :], golden_logits, rtol=1e-01, atol=1e-01, equal_nan=False
)



if __name__ == "__main__":
jax.config.update("jax_default_prng_impl", "unsafe_rbg")
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
pyconfig.initialize(sys.argv)
cfg = pyconfig.config
main(cfg)
3 changes: 3 additions & 0 deletions end_to_end/tpu/llama2/7b/test_llama2_7b.sh
@@ -71,3 +71,6 @@ export NEW_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${PARAMETER_CHECKPOINT_RUN}/checkp

# We run decoding on the fine-tuned parameter checkpoint
python3 MaxText/decode.py MaxText/configs/base.yml load_parameters_path=${NEW_CKPT_PATH} run_name=runner_decode_finetuned_${idx} base_output_directory=${BASE_OUTPUT_DIRECTORY} per_device_batch_size=1 model_name='llama2-7b' ici_autoregressive_parallelism=4 max_prefill_predict_length=4 max_target_length=16 prompt="I love to" attention=dot_product scan_layers=false

# We also test whether the forward pass logits match the golden logits for Llama2-7b
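# Note: max_target_length=4 is assumed to match the tokenized length (including BOS) of the "I love to" prompt
# in the golden data, and dtype=float32 is presumably chosen to keep the comparison in full precision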
python3 MaxText/tests/forward_pass_logit_checker.py MaxText/configs/base.yml load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test per_device_batch_size=1 model_name=llama2-7b ici_tensor_parallelism=4 max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic dtype=float32 scan_layers=false
1 change: 1 addition & 0 deletions requirements.txt
@@ -27,3 +27,4 @@ tensorboardx
tensorboard-plugin-profile
git+https://github.com/mlperf/logging.git
google-jetstream
jsonlines
