-
Notifications
You must be signed in to change notification settings - Fork 230
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add forward pass logit check test for Llama2-7b
- Loading branch information
Showing
8 changed files
with
256 additions
and
0 deletions.
There are no files selected for viewing
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d13ebbb",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu\n",
    "!pip3 install tokenizers -U\n",
    "!pip3 install transformers -U\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a8a4bb6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff804403",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the tokenizer and model from Hugging Face\n",
    "model_id = \"meta-llama/Llama-2-7b-hf\"\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    model_id,\n",
    "    torch_dtype=torch.float32,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9f218ba6",
   "metadata": {},
   "source": [
    "## looping over multiple prompts and logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c567f8d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Path of the golden-data jsonlines file to produce\n",
    "output_path = \"golden_data_llama2-7b.jsonl\"\n",
    "\n",
    "# Prompts to capture golden logits for; one jsonl line is written per prompt\n",
    "prompt_texts = [\"I love to\", \"Today is a\", \"What is the\"]\n",
    "\n",
    "# Open the file ONCE, outside the loop: opening with 'w' inside the loop\n",
    "# truncates the file on every iteration and keeps only the last prompt.\n",
    "with open(output_path, 'w') as f:\n",
    "    for prompt_text in prompt_texts:\n",
    "        # Encode the prompt text\n",
    "        input_ids = tokenizer.encode(prompt_text, return_tensors='pt')\n",
    "\n",
    "        with torch.no_grad():\n",
    "            # Greedy decoding\n",
    "            output = model.generate(input_ids, max_length=input_ids.shape[1] + 10, num_return_sequences=1)\n",
    "\n",
    "        # Decode the generated ids to a list of tokens (for human inspection)\n",
    "        generated_tokens = tokenizer.convert_ids_to_tokens(output[0])\n",
    "        print(generated_tokens)\n",
    "\n",
    "        # Get the logits for the prompt + completion\n",
    "        with torch.no_grad():\n",
    "            outputs = model(output)\n",
    "            logits = outputs.logits\n",
    "\n",
    "        # Convert logits to fp32\n",
    "        logits = logits.cpu().numpy().astype('float32')\n",
    "\n",
    "        # Prepare data to be saved. The consumer, forward_pass_logit_checker.py,\n",
    "        # reads 'tokens' as int32 token ids and compares 'logits' against a\n",
    "        # 2D (sequence, vocab) slice, so store ids (not token strings) and\n",
    "        # drop the batch dimension from the logits.\n",
    "        data_to_save = {\n",
    "            \"prompt\": prompt_text,\n",
    "            \"completion\": tokenizer.decode(output[0]),\n",
    "            \"tokens\": output[0].tolist(),\n",
    "            \"logits\": logits[0].tolist()\n",
    "        }\n",
    "        # One JSON object per line (jsonlines format), so multiple test\n",
    "        # cases accumulate instead of overwriting each other.\n",
    "        f.write(json.dumps(data_to_save) + \"\\n\")\n",
    "\n",
    "        print(f\"Data saved to {output_path}\")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
File renamed without changes.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This forward_pass_logit_checker.py file compares the logits generated by the MaxText implementation for some input
# prompts with the golden logits for those prompts for a particular model. It is generic in that it can work with
# different models; it expects an input file called golden_data_<model_name>.jsonl to be present under
# MaxText/test_assets, e.g., MaxText/test_assets/golden_data_llama2-7b.jsonl.
# The golden jsonl file is a simple jsonlines file where each line is a dictionary containing the following
# required keys:
#   1. prompt: a string representing the prompt, e.g., "I love to"
#   2. tokens: token ids obtained by tokenizing the prompt
#   3. logits: golden logits, i.e., the ideal logits generated by the model in question when fed the prompt in #1
# There can be multiple such test cases in the jsonl file; each test case is a new line in the file.
# This script runs the forward pass with the input tokens and asserts that the logits generated by the MaxText
# implementation of the same model closely match the golden logits.
# Users can use a script similar to MaxText/scratch_code/golden_llama2-7b_export.ipynb to create the jsonl file.
"""Check if the logits generated by a model's MaxText implementation matches golden logits for the same inputs""" | ||
import sys | ||
import os | ||
|
||
current_dir = os.path.dirname(os.path.abspath(__file__)) | ||
maxtext_parent_dir = os.path.dirname(current_dir) | ||
sys.path.append(maxtext_parent_dir) | ||
|
||
import max_logging | ||
max_logging.log(f"Added parent directory = {maxtext_parent_dir}") | ||
|
||
import common_types | ||
import jax | ||
import jax.numpy as jnp | ||
import numpy as np | ||
import pyconfig | ||
import jsonlines | ||
import train | ||
|
||
|
||
def get_data(golden_data, golden_data_index, config):
  """Get the golden data for the test indexed at golden_data_index.

  Args:
    golden_data: list of dicts loaded from the golden jsonl file; each dict
      must contain at least 'prompt', 'tokens' (token ids) and 'logits'.
    golden_data_index: index of the test case to load.
    config: MaxText config providing global_batch_size_to_train_on and
      max_target_length.

  Returns:
    ids: (batch, num_tokens) int32 token ids — the prompt's tokens
      replicated across every batch row.
    decoder_segment_ids: (batch, max_target_length) array marking every
      position as part of the active decoding sequence.
    decoder_positions: (batch, max_target_length) int32 position indices
      0..max_target_length-1 per row.
    logits: golden logits for the prompt as a float32 numpy array.
  """
  max_logging.log(f"Comparing forward pass for golden data index = {golden_data_index} ")
  max_logging.log(f"config.global_batch_size_to_train_on={config.global_batch_size_to_train_on}")
  s = (config.global_batch_size_to_train_on, config.max_target_length)
  ids = np.asarray(golden_data[golden_data_index]['tokens'], dtype=np.int32)

  logits = np.asarray(golden_data[golden_data_index]['logits'], dtype=np.float32)
  max_logging.log(f" prompt=\"{golden_data[golden_data_index]['prompt']}\" raw ids={ids}, logits.shape = {logits.shape}")

  # Every position belongs to the single active sequence. (Use the jnp alias
  # consistently with the rest of this function instead of jax.numpy.)
  decoder_segment_ids = jnp.zeros(s) + common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR
  # Position indices for each batch row.
  decoder_positions = jnp.stack(
      [jnp.arange(config.max_target_length, dtype=jnp.int32) for _ in range(config.global_batch_size_to_train_on)]
  )

  # Replicate the single prompt across the whole batch.
  ids = jnp.stack(
      [ids for _ in range(config.global_batch_size_to_train_on)]
  )
  max_logging.log(f"ids={ids}, decoder_segment_ids = {decoder_segment_ids}, decoder_positions= {decoder_positions}")

  return ids, decoder_segment_ids, decoder_positions, logits
def main(config):
  """Compare MaxText forward-pass logits for config.model_name with golden logits.

  Loads MaxText/test_assets/golden_data_<model_name>.jsonl, runs the MaxText
  model's forward pass on each golden prompt's tokens and asserts the produced
  logits are close to the stored golden logits.

  Args:
    config: MaxText config (from pyconfig); names the model whose golden data
      file must exist under MaxText/test_assets.

  Raises:
    AssertionError: if the MaxText logits do not match the golden logits
      within the tolerances below.
  """
  # Initialize the model and restore its weights (e.g. Llama2-7b from Meta).
  (
      init_rng,
      _,
      _,
      _,
      model,
      _,
      _,
      _,
      _,
      state,
  ) = train.setup_train_loop(config)

  input_golden_data_path = "MaxText/test_assets/golden_data_" + config.model_name + ".jsonl"
  with jsonlines.open(input_golden_data_path, 'r') as f:
    golden_data = list(f)

  for golden_data_index in range(len(golden_data)):
    ids, decoder_segment_ids, decoder_positions, golden_logits = get_data(golden_data, golden_data_index, config)

    # Forward pass with dropout disabled so the output is deterministic.
    full_train_logits = model.apply(
        state.params,
        ids,
        decoder_positions,
        decoder_segment_ids,
        enable_dropout=False,
        rngs={"aqt": init_rng},
    )

    max_logging.log(f"{golden_logits[0] =}, {full_train_logits[0, 0, :]=}")
    # Compare batch row 0 only — every row is an identical replica of the
    # prompt (see get_data). Use the jnp alias consistently, and attach a
    # message so a failure identifies which golden test case mismatched.
    assert jnp.allclose(
        full_train_logits[0, :, :], golden_logits, rtol=1e-01, atol=1e-01, equal_nan=False
    ), f"Logits mismatch for golden data index {golden_data_index}"
if __name__ == "__main__": | ||
jax.config.update("jax_default_prng_impl", "unsafe_rbg") | ||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" | ||
pyconfig.initialize(sys.argv) | ||
cfg = pyconfig.config | ||
main(cfg) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters