save and load quantized checkpoint #671

Open
wants to merge 3 commits into
base: main

Conversation

jwyang-google
Collaborator

No description provided.

# )

restored = ckptr.restore(p)
print(restored['params'].keys()) # printed ['params', 'aqt']
Collaborator Author

@singh-mitali The original restore call can only load the 'params' key, but the checkpoint we currently save puts the weights under both 'params' and 'aqt'.

This change can load a checkpoint that has both 'params' and 'aqt', but it runs into a sharding issue when the model actually runs. I suspect that is because we no longer pass the second argument, items, as the original restore() call did. I'm not sure how to split abstract_unboxed_pre_state.params so that it matches our new quantized checkpoint.

Collaborator Author

jwyang-google May 24, 2024

Another option would be to save the quantized checkpoint in the original format so that we don't need to change the loading code, but I found that the path keys under 'aqt' in the quantized checkpoint differ from those under 'params' in the original checkpoint.
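To make the mismatch concrete, here is a minimal sketch of how the two sets of path keys could be compared. The nested dicts and every key name in them (decoder, mlp, wi, kernel, qkernel, value, scale) are made-up placeholders, not the paths aqt actually generates; flatten_paths is a hypothetical helper written only for this illustration.

```python
def flatten_paths(tree, prefix=()):
  """Return the set of key paths (tuples) leading to every leaf of a nested dict."""
  paths = set()
  for key, value in tree.items():
    path = prefix + (key,)
    if isinstance(value, dict):
      paths |= flatten_paths(value, path)
    else:
      paths.add(path)
  return paths


# Placeholder trees: the real key names come from the model and from the aqt code.
original = {'params': {'decoder': {'mlp': {'wi': {'kernel': 'bf16 weights'}}}}}
quantized = {
    'params': {'decoder': {'mlp': {'wi': {}}}},
    'aqt': {'decoder': {'mlp': {'wi': {'qkernel': {'value': 'int8 weights',
                                                   'scale': 'per-vector scales'}}}}},
}

params_paths = flatten_paths(original['params'])
aqt_paths = flatten_paths(quantized['aqt'])

# Paths that exist on one side only are the ones a format conversion would have to map.
only_in_params = sorted(params_paths - aqt_paths)
only_in_aqt = sorted(aqt_paths - params_paths)
print('only under params:', only_in_params)
print('only under aqt:', only_in_aqt)
```

A non-empty difference on either side means the loading code cannot treat the 'aqt' subtree as a drop-in replacement for 'params' without a mapping step.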

Collaborator

Yes, they are different: they are generated by the aqt code, and the aqt code will expect the same paths when it reads the checkpoint after loading. In that case we would have to do the conversion on both ends (and these paths can also change with aqt code changes or with which layers we apply aqt to). Besides the int8 values for the weights, there is an additional scale parameter per quantized vector. It is better to see whether we can just save the two parameters separately in the checkpoint.
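For readers unfamiliar with why a scale is stored next to the int8 values, the following is a rough sketch of generic symmetric per-vector int8 quantization in plain Python. It is not the aqt implementation; quantize_vector, dequantize_vector, and the sample weights are hypothetical and only illustrate that each quantized vector needs its own scale to be dequantized.

```python
def quantize_vector(vec):
  """Symmetric int8 quantization of one vector; returns (int values, scale)."""
  max_abs = max(abs(x) for x in vec)
  # An all-zero vector gets scale 1.0 so that division below stays defined.
  scale = max_abs / 127.0 if max_abs > 0 else 1.0
  q = [max(-127, min(127, round(x / scale))) for x in vec]
  return q, scale


def dequantize_vector(q, scale):
  """Recover approximate float values from int values and their scale."""
  return [v * scale for v in q]


# Made-up weight matrix: each row is quantized as its own vector.
weights = [[0.5, -1.0, 0.25], [2.0, 0.0, -0.5]]
quantized_rows = [quantize_vector(row) for row in weights]
q_values = [q for q, _ in quantized_rows]  # what would be stored as int8
scales = [s for _, s in quantized_rows]    # one extra float per vector
restored_rows = [dequantize_vector(q, s) for q, s in quantized_rows]
```

Both q_values and scales have to be saved: dropping the scales would make the int8 values impossible to map back to the original range.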

Collaborator

So we could create a state called QuantizedDecodeState that has both params and aqt as variables, and store that in the checkpoint. Alternatively, use multi-object checkpointing: https://colab.sandbox.google.com/github/google/orbax/blob/main/checkpoint/orbax/checkpoint/orbax_checkpoint.ipynb
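A rough sketch of what such a container might look like, using only the standard library. Only the name QuantizedDecodeState and its two fields come from the comment above; the methods to_checkpoint_tree and from_checkpoint_tree and the sample variables are hypothetical. In a JAX codebase the class would likely also need to be registered as a pytree (for example via flax.struct.dataclass) before a checkpointer can handle it, and that step is omitted here.

```python
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class QuantizedDecodeState:
  """Holds the regular variables and the aqt-generated ones side by side."""
  params: Dict[str, Any] = field(default_factory=dict)
  aqt: Dict[str, Any] = field(default_factory=dict)

  def to_checkpoint_tree(self):
    """Tree with two top-level collections, as it would be handed to a saver."""
    return {'params': self.params, 'aqt': self.aqt}

  @classmethod
  def from_checkpoint_tree(cls, tree):
    """Rebuild the state from a restored tree that has both collections."""
    return cls(params=tree['params'], aqt=tree['aqt'])


# Placeholder variables standing in for the output of a quantized model.
variables = {
    'params': {'layer_0': {'bias': [0.0, 0.0]}},
    'aqt': {'layer_0': {'qkernel': {'value': [1, -2], 'scale': [0.01]}}},
}
state = QuantizedDecodeState.from_checkpoint_tree(variables)
round_trip = state.to_checkpoint_tree()
print(sorted(round_trip.keys()))
```

Keeping the two collections as separate fields avoids renaming any aqt-generated paths, so the aqt code can read them back unchanged.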

state.params['params'] = params['params']
state.params['aqt'] = params['aqt']
if save_checkpoint(checkpoint_manager, step_number_to_save_new_ckpt, state):
  print('save checkpoint successfully')
Collaborator Author

This successfully saves a quantized checkpoint.

2 participants