-
Notifications
You must be signed in to change notification settings - Fork 244
/
forward_pass_logit_checker.py
121 lines (97 loc) · 4.62 KB
/
forward_pass_logit_checker.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This forward_pass_logit_checker.py file compares the logits generated by the MaxText implementation for some input
# prompts with the golden logits for those input prompts for a particular model. The script is generic in that
# it can work with different models; it expects an input file called golden_data_<model_name>.jsonl to be present
# under MaxText/test_assets,
# e.g., MaxText/test_assets/golden_data_llama2-7b.jsonl
# The golden jsonl file is a simple jsonlines file in which each line is a dictionary containing the following
# required keys:
# 1. prompt: A string representing the prompt, e.g., "I love to",
# 2. tokens: token ids after tokenizing the prompt,
# 3. logits: golden logits, meaning the ideal logits generated by the model in question when fed with the prompt in #1
# There can be multiple such test cases in the jsonl file; each test case is a new line in the jsonl file.
# This forward_pass_logit_checker.py runs the forward pass with the input tokens and asserts that the logits generated by the
# MaxText implementation of the same model match the golden logits closely.
# Users could use a script similar to MaxText/scratch_code/golden_llama2-7b_export.ipynb to create this jsonl file.
"""Check if the logits generated by a model's MaxText implementation match golden logits for the same inputs"""
import sys
import os
current_dir = os.path.dirname(os.path.abspath(__file__))
maxtext_parent_dir = os.path.dirname(current_dir)
sys.path.append(maxtext_parent_dir)
import max_logging
max_logging.log(f"Added parent directory = {maxtext_parent_dir}")
import common_types
import jax
import jax.numpy as jnp
import numpy as np
import pyconfig
import jsonlines
import train
def get_data(golden_data, golden_data_index, config):
    """Build the forward-pass inputs and golden logits for one test case.

    Args:
      golden_data: Indexable collection of records; each record is read by the
        keys 'prompt', 'tokens' and 'logits'.
      golden_data_index: Position of the record to use.
      config: Object providing `global_batch_size_to_train_on` and
        `max_target_length`.

    Returns:
      A tuple `(ids, decoder_segment_ids, decoder_positions, logits)`:
        * ids: the record's token ids (int32) stacked once per batch row,
        * decoder_segment_ids: zeros of shape (batch, max_target_length) plus
          `common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR`,
        * decoder_positions: `arange(max_target_length)` (int32) stacked once
          per batch row,
        * logits: the record's golden logits as a float32 NumPy array.
    """
    max_logging.log(f"Comparing forward pass for golden data index = {golden_data_index} ")
    max_logging.log(f"config.global_batch_size_to_train_on={config.global_batch_size_to_train_on}")

    batch_size = config.global_batch_size_to_train_on
    seq_len = config.max_target_length

    # Pull the raw token ids and expected logits out of the selected record.
    record = golden_data[golden_data_index]
    raw_ids = np.asarray(record['tokens'], dtype=np.int32)
    golden_logits = np.asarray(record['logits'], dtype=np.float32)
    max_logging.log(f" prompt=\"{record['prompt']}\" raw ids={raw_ids}, logits.shape = {golden_logits.shape}")

    # Every position of every batch row gets the same segment indicator value.
    decoder_segment_ids = jnp.zeros((batch_size, seq_len)) + common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR

    # The same position row and the same prompt are repeated for each batch row.
    position_row = jnp.arange(seq_len, dtype=jnp.int32)
    decoder_positions = jnp.stack([position_row for _ in range(batch_size)])
    ids = jnp.stack([raw_ids for _ in range(batch_size)])

    max_logging.log(f"ids={ids}, decoder_segment_ids = {decoder_segment_ids}, decoder_positions= {decoder_positions}")
    return ids, decoder_segment_ids, decoder_positions, golden_logits
def main(config):
    """Run the forward pass on every golden test case and check the logits.

    Sets up the model and train state via `train.setup_train_loop(config)`, reads
    `MaxText/test_assets/golden_data_<config.model_name>.jsonl` (a path relative
    to the current working directory), and for each record compares the logits
    computed for batch row 0 against the record's golden logits.

    Args:
      config: Configuration object. `model_name` is read here; the object is
        also passed on to `train.setup_train_loop` and `get_data`.

    Raises:
      AssertionError: If the computed logits of any test case are not within
        rtol=1e-01 / atol=1e-01 of the golden logits.
    """
    # Set up the model and train state for the configured model. Only the RNG,
    # the model and the state are used from the returned 10-tuple.
    (
        init_rng,
        _,
        _,
        _,
        model,
        _,
        _,
        _,
        _,
        state,
    ) = train.setup_train_loop(config)

    input_golden_data_path = "MaxText/test_assets/golden_data_" + config.model_name + ".jsonl"
    with jsonlines.open(input_golden_data_path, 'r') as f:
        golden_data = list(f)

    for golden_data_index in range(len(golden_data)):
        ids, decoder_segment_ids, decoder_positions, golden_logits = get_data(golden_data, golden_data_index, config)

        full_train_logits = model.apply(
            state.params,
            ids,
            decoder_positions,
            decoder_segment_ids,
            enable_dropout=False,
            rngs={"aqt": init_rng},
        )

        max_logging.log(f"{golden_logits[0] =}, {full_train_logits[0, 0, :]=}")

        # All batch rows carry the same prompt, so only row 0 is compared.
        logits_match = jax.numpy.allclose(
            full_train_logits[0, :, :], golden_logits, rtol=1e-01, atol=1e-01, equal_nan=False
        )
        # Raise explicitly instead of using a bare `assert`: assert statements
        # are removed under `python -O`, which would make this check a no-op.
        if not logits_match:
            raise AssertionError(
                f"Logits for golden data index {golden_data_index} in {input_golden_data_path} "
                "do not match the golden logits (rtol=1e-01, atol=1e-01)"
            )
if __name__ == "__main__":
    # Use the "unsafe_rbg" implementation as JAX's default PRNG.
    jax.config.update("jax_default_prng_impl", "unsafe_rbg")
    # Set TF_CPP_MIN_LOG_LEVEL to "0" before the run starts.
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
    # Build the global config from the command-line arguments, then run the check.
    pyconfig.initialize(sys.argv)
    main(pyconfig.config)