[RFC] Improve TorchTune Extensibility and Build Interop with Ecosystem #442
Conversation
High level comment - totally agree with the proposal to support HF formats for popular models as well. Though wanted to get into detail about this point -
Curious how the HF format specifically enables easier support for llama2 13B and 30B? AFAIK, meta format checkpoints are published for these models as well, so we should be able to use those formats just as easily as well. Just wondering whether the HF format provides something additional for the larger scale models specifically.
@rohan-varma Sorry this was poorly worded. Adding the
@kartikayk That makes sense. It seems like we should discuss how we should read and write multiple files? IIUC, the files are checkpoint shards, so what I've been doing is using distributed to stitch them together so we can load into torchtune models, i.e. in https://github.com/pytorch-labs/torchtune/pull/404/files. Curious about the approach that you took to load 13b?
Thanks for creating this RFC! A few questions I have on the proposal:
Sorry this wound up being more than a few questions .. 😅 .. let me know if any of these comments are unclear.
@rohan-varma why do I need distributed for this? Here's a simple code pointer from gpt-fast on how they handle the
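For context, a minimal sketch (not the gpt-fast code itself) of how shard files can be stitched together on a single host, assuming each `pytorch_model-*.bin` file holds a disjoint subset of the keys:

```python
from pathlib import Path

import torch


def merge_hf_shards(checkpoint_dir: str) -> dict:
    """Merge HF-style shard files into one state dict without torch.distributed."""
    merged_state_dict = {}
    for ckpt_file in sorted(Path(checkpoint_dir).glob("pytorch_model-*.bin")):
        shard = torch.load(ckpt_file, map_location="cpu")
        # Each shard holds a disjoint subset of keys, so a plain update suffices.
        merged_state_dict.update(shard)
    return merged_state_dict
```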
@ebsmothers great questions! A few thoughts:
If you take a closer look, the format isn't model specific. It's "training framework" specific. Adding support for a model like Mistral isn't a new format. It's a new convert function which likely has a similar/same key mapping dict.
We'll need to figure this out once we have a concrete method that we can take a look at. But the answer is that we should align with the inference tooling on this. And that's the contract I propose here.
I view this as an interaction with the ecosystem which happens to depend on HF formats for everything. You're right that if things break that we'll have to go in and fix stuff, but the claim is that this is true generally. If tomorrow some other format becomes popular, we should just align with that. Does this make sense?
A few thoughts on this:
Simply put - format of state dict that you can load into our model class i.e.
Responded a bit, I don't think I understand what "distributed fashion" means here. See the code pointer above, the files in the HF repo for 13, 70B models. Maybe I'm missing something, so need to understand this more.
So IIUC, in those repos, what @ebsmothers means by the files are in a distributed fashion is that each model related file is a sharded checkpoint of the entire model. For example, llama 2 70b checkpoint has 8 checkpoint files: https://huggingface.co/meta-llama/Llama-2-70b/tree/main. Each file contains the rank's owning shard from row / column parallel sharding. When converting to / from torchtune format, we'll need to appropriately unshard - which to me, requires being aware of whether the particular tensor was column or row parallel sharded.
I don't think this is necessarily true for the HF checkpoint, but I also need to do a lot more HW for 70B. Let's punt the discussion on 70B though. I think the complexity that 70B brings is orthogonal to this RFC. I'm also in the process of redesigning the entire checkpointing stack since this is sub-optimally written in its current form. This will allow us to decouple the checkpointing across model sizes (7b vs 70b) and across finetuning methods (full vs LoRA), which makes sense to me since a user fine-tuning a 7B or smaller model doesn't need to care about any of the complexity associated with the 70B model.
The `FullModelCheckpointer` generally makes sense to me, but I also think this is the easy case 😃. The main thing I'd push on here is more details around common PEFT partial save/load flows and making sure we properly support distributed checkpointing APIs (on that second point I would take a look at #443 if you haven't already)
# different formats for the same model have different dtypes. Track this
# to convert weights to the right format before writing to file
self._checkpoint_dtype = None
Any use case for different weights in different dtypes?
if intermediate_checkpoint:
    # Add the relevant checkpoint format information to the state dict
    checkpoint_dict["checkpoint_dtype"] = self._checkpoint_dtype
    checkpoint_dict["weight_map"] = self._weight_map
    checkpoint_dict["checkpoint_format"] = self._checkpoint_format.name

    # We write to a single ".pt" file irrespective of the extension provided by the recipe
    output_path = Path.joinpath(self._output_dir, intermediate_checkpoint_name).with_suffix(".pt")
    torch.save(checkpoint_dict, output_path)
    logger.info(
        "Model checkpoint of size "
        f"{os.path.getsize(output_path) / 1000**3:.2f} GB "
        f"saved to {output_path}"
    )
else:
more of a nit, but if the whole method is an if/else maybe it should be split into e.g. `save_intermediate_checkpoint` and `save_final_checkpoint` methods?
* Mid-training Checkpointing. In this case the state-dict contains more information than just the model weights. It also contains the optimizer state dict and the training state of the recipe needed to correctly restart training. The construction of the state-dict, including figuring out what information is needed to correctly resume training, is handled by the recipe. The checkpointer doesn't know or care about how the state-dict is constructed. In this scenario, the checkpointer simply adds additional information about the original checkpoint (eg: format, weight map etc) to ensure the final checkpoint is constructed correctly in case the current training run fails and needs to be resumed. Intermediate checkpoints don't require any conversion since these are directly saved in the ``TORCHTUNE_FORMAT``.
If the claim is that we are using this Checkpointer class to define and enforce a unified TorchTune format, the mid-training checkpointing contract here feels a little strange. It also introduces a lot of coupling between the recipe and the checkpointing class (for instance, the recipe is responsible for defining the logic to save an intermediate checkpoint, but the checkpointer then needs to know about this when loading the intermediate checkpoint). Lmk if I'm misunderstanding this though
@ebsmothers actually the intent here is the opposite of what you mentioned. Right now the checkpointer and the recipe are really intertwined, with the checkpointing code under utils having to know about recipe-specific keys and how the checkpoint is formatted. With this change, the separation of concerns is a lot cleaner i.e:
- The recipe handles all of the logic for preparing intermediate checkpoints (it knows what state it needs to resume correctly)
- The checkpointer handles the external checkpoint loading and final checkpoint saving.
Generally there will be some coupling between the two because the checkpointer is meant to be used in the recipe. But generally checkpointers should be recipe agnostic. Let me know if this makes sense.
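For illustration, a rough sketch of that contract. The class name, the `load_checkpoint`/`save_checkpoint` methods and the `intermediate_checkpoint` flag come from this PR; the constructor arguments and the recipe-side keys are placeholders and may not match the final code:

```python
import torch

from torchtune.utils import FullModelHFCheckpointer  # class added in this PR


def load_and_checkpoint(model: torch.nn.Module, optimizer: torch.optim.Optimizer) -> None:
    # The checkpointer owns reading the external (HF-format) checkpoint and
    # converting it into a torchtune-format state dict.
    checkpointer = FullModelHFCheckpointer(
        checkpoint_dir="/tmp/llama2-hf",
        checkpoint_files=[
            "pytorch_model-00001-of-00002.bin",
            "pytorch_model-00002-of-00002.bin",
        ],
        output_dir="/tmp/finetuned_llama",
        resume_from_checkpoint=False,
    )
    state_dict = checkpointer.load_checkpoint()
    model.load_state_dict(state_dict["model"])

    # The recipe owns deciding what goes into an intermediate checkpoint
    # (keys below are illustrative); the checkpointer just writes it out.
    intermediate_state = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    checkpointer.save_checkpoint(intermediate_state, intermediate_checkpoint=True)
```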
for key in state_dict.keys():
    self._weight_map[key] = ckpt_file.name
merged_state_dict.update(state_dict)
What if params are sharded across files? I.e. each individual file contains a subset of the weights for a single state dict key
This should be a new convertor. I don't have a case for this right now till we add support for 13B or 70B. But I expect that to go in a different checkpointer class (or maybe we can update this one if it makes sense).
@ebsmothers thank you for the review
So this PR won't address this complexity. My main point is that having a separate training component will make things easier to separate i.e. not have FFT users worry about PEFT checkpointing. Is there anything in the design specifically that stops us from doing this? In the worst case I can map
One question I have: given "Convert final checkpoint BACK to the original format before writing out to file", the existence of llama.cpp/convert.py is just a convenience util, right? I can take the "checkpoint that is converted back to original format" to the llama.cpp repo as well, right? To me the latter feels more composable compared to writing a thin wrapper. Thoughts?
@kimishpatel Sorry I should clarify this point better.
README.md
The argument passed to `--nproc_per_node` can be varied depending on how many GPUs you have. A full finetune can be memory-intensive, so make sure you are running on enough devices. See [this table](https://github.com/pytorch-labs/torchtune/blob/main/README.md#finetuning-resource-requirements) for resource requirements on common hardware setups.
Similarly, you can finetune with LoRA on the Alpaca dataset on two devices via the following. Remember to convert your model with ```train_type``` set to ```lora'```
model with ```train_type``` set to ```lora'```
model with `train_type` set to `lora` |
ckpts = (
    ["llama2.llama2_7b"]
    if large_scale
    else [
        "small_test_ckpt_tune",
        "small_test_ckpt_hf",
        "small_test_ckpt_meta",
    ]
)
This should just be a `@pytest.mark.parametrize`, no?
I thought about that and I don't think that'll work when we pass in --large-scale? I actually don't know how that works, so went with a manual for loop. Let me know if you think that will work.
Hmmm good point. I don't immediately see how to do this. Ideally we should do this more cleanly, but no need to block this PR on it
@@ -164,15 +228,21 @@ def test_gradient_accumulation(
# We use a tiny model to reduce the error accumulation in the test
# It's impossible to make a large model produce the same loss values
# in the same way as the full batch size.
- model_ckpt = "llama2_tiny_test_ckpt"
+ model_ckpt = "small_test_ckpt_tune"
Oh this works? I thought we needed the tiny checkpoint for the accumulation of errors to be small enough
Hmm, it works for me :)
cmd = f""" | ||
tune full_finetune \ | ||
--config {_CONFIG_PATH} \ | ||
--override \ | ||
model._component_=torchtune.models.{model_ckpt} \ | ||
model_checkpoint={fetch_ckpt_model_path(model_ckpt)} \ | ||
checkpointer._component_=torchtune.utils.FullModelTorchTuneCheckpointer \ |
nit: you could just write a single base command, then append micro_batch_size cfg for the first invocation and gradient_accumulation_steps cfg for the second invocation, just to make it clearer that's all that's changing. But obv it was in this state before you got here so not a huge deal if you don't do this
I'll follow up with changes beyond this PR in a separate PR. It's complex as it is
assert len(output_state_dict_1.keys()) + 1 == len(orig_state_dict_1.keys())
assert len(output_state_dict_2.keys()) + 1 == len(orig_state_dict_2.keys())
inv freq accounts for the +1 here as well?
Yes, we have `inv_freq` in each layer. Since each dict corresponds to a single layer, both of them are off by 1.
class FullModelTorchTuneCheckpointer(_CheckpointerInterface):
    """
    Checkpointer which reads and writes "full-model" checkpoints in a format compatible with
nit: I think the definition of "full-model" is not sufficiently clear from the context here
self._checkpoint_path = Path.joinpath(self._checkpoint_dir, checkpoint_files[0])
if (
    not self._checkpoint_path.is_file()
    or not self._checkpoint_path.suffix == ".pt"
Just curious: is the .pt extension a hard requirement? Clearly we've had non-pt extensions floating around for a while without really breaking anything
It's very unintuitive, I'd just make this a convention for checkpoints we write
# if resume_from_checkpoint is True, recipe_state.pt should contain the recipe state
if self._resume_from_checkpoint:
    self._recipe_state_file = get_recipe_checkpoint_path(
        self._checkpoint_dir, filename="recipe_state.pt"
Would consider defining `recipe_state.pt` as a constant somewhere (maybe class-level) for increased visibility
Also I know you mentioned saving as a JSON previously, curious if there's any particular reason for switching to .pt?
Good point about the constants. Let me generalize this a bit when I do the LoRA PR.
Optim state can be quite heavy (as large as model checkpoints) and so saving as json was a bad idea.
Model:
{
    "key_1": weight
    ...
}

Recipe State:
Would maybe explicitly add a comment about the filename each of these is saved to (I know you already mention `recipe_state.pt` above). At first glance I kinda thought this was all a single state dict, I think adding filenames explicitly may make it harder to make that mistake.
self._resume_from_checkpoint = resume_from_checkpoint

# weight_map contains the state_dict key -> checkpoint file mapping so we can correctly
# parition the state dict into output checkpoint files. This is updated during checkpoint
nit
# parition the state dict into output checkpoint files. This is updated during checkpoint
# partition the state dict into output checkpoint files. This is updated during checkpoint |
f"Found {type(value)} instead." | ||
) | ||
# idx is written in the 4 digit format (eg: 0001, 0002, etc.) | ||
self._weight_map[key] = f"{cpt_idx+1:04}" |
Probably a dumb q: why don't we just use the filename directly as the key here rather than going through all the sorting and indexing business?
So the file names here are of the form `pytorch_model-00001-of-00003.bin`, `pytorch_model-00002-of-00003.bin` and so on. The ID (00001) is important since the weights are written in this order. But the entire filename doesn't serve any purpose (for now). So I simplified all of the filename logic and just added the ID by making sure the incoming names are lexicographically sorted.
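In other words, something like this minimal sketch (illustrative only, not the actual implementation):

```python
# Sort the incoming shard names lexicographically and record only a 4-digit
# index per file, since the full filename carries no extra information.
filenames = [
    "pytorch_model-00002-of-00003.bin",
    "pytorch_model-00001-of-00003.bin",
    "pytorch_model-00003-of-00003.bin",
]
file_ids = {name: f"{idx + 1:04}" for idx, name in enumerate(sorted(filenames))}
# {'pytorch_model-00001-of-00003.bin': '0001', 'pytorch_model-00002-of-00003.bin': '0002', ...}
```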
num_heads=self._config["num_attention_heads"],
num_kv_heads=self._config["num_key_value_heads"],
dim=self._config["hidden_size"],
Do these keys always exist?
For llama models they always do. I checked a few others and they do as well. But might be a good idea to add a try-catch block here
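Something along these lines could work as the guard (a sketch only; the key names follow the HF llama config.json convention, and the helper name and error message are made up):

```python
import json


def read_llama_config(config_path: str) -> dict:
    """Read the HF config.json and surface a clear error if a key is missing."""
    with open(config_path) as f:
        config = json.load(f)
    try:
        return {
            "num_heads": config["num_attention_heads"],
            "num_kv_heads": config["num_key_value_heads"],
            "dim": config["hidden_size"],
        }
    except KeyError as e:
        raise ValueError(
            f"Expected key {e} in {config_path}; is this a llama-style HF config?"
        ) from e
```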
# If the recipe state needs to be output, first remove the model state dict
if intermediate_checkpoint:
    _ = state_dict.pop("model")
    torch.save(state_dict, Path.joinpath(self._output_dir, "recipe_state.pt"))
Just to check my understanding: when saving intermediate checkpoints, we still save in the input format (as opposed to TorchTune format, which we were doing previously)? And then we just supplement with the additional `recipe_state.pt` file?
Yeah, that's right. This way if I want to run inference or eval on intermediate checkpoints, I don't need to do any conversion. Also it simplifies save significantly, which is nice
Checkpointer which reads and writes "full-model" checkpoints in Meta's format. Example includes
the Llama-2-7b model from the meta-llama repo (https://huggingface.co/meta-llama/Llama-2-7b)
Same link as the HF checkpointer?
I think it's different?
Missed those last two letters..
def load_checkpoint(self) -> Dict[str, Any]:
    """
    Load TorchTune checkpoint from file. Currently only loading from a single file is supported.
Similar comment here, is this line copy-paste from the equivalent in `FullModelTorchTuneCheckpointer`? Just generally, make sure to update docstrings for all of these methods
}
"""
state_dict: Dict[str:Any] = {}
state_dict["model"] = safe_torch_load(self._checkpoint_path)
I thought we talked about getting rid of the "model" key? I might be missing something/misremembering though, lmk if so
I get rid of "model" for checkpoint save. For the state dicts being sent to the recipe we still need this to have a single dict with both model and recipe state
Left a bunch more comments, but I think all the major design concerns from my side are addressed in this latest version. So modulo my open comments, this looks good to me.
This is a legendary PR, as usual! Couple of points I want to make sure we discuss prior to landing (also mostly okay with discussing these post land if absolutely needed, save for the question around security and extending to sharded checkpoints):
- User having to specify the `train-type` when converting their ckpts into our training format seems a bit unintuitive to me. Could you clarify why we have to do this for the moment, what's blocking us from not having to specify this, and how we can get there?
- Seems like we would need to write a LoRA-specific checkpointer for this to work with LoRA. Is the reason the same as (1), and can we have a discussion here?
- Curious about the `del state_dict` and gc collect - were you seeing memory issues here?
- Not directly related to this PR, but I'm realizing having `tune download` + operating on the checkpoints could introduce a security risk if we allow users to specify arbitrary checkpoint paths - we're downloading somewhat arbitrary checkpoint files from a third-party and running python code on them. This could also be on sensitive HW such as company devices. See https://www.darkreading.com/application-security/hugging-face-ai-platform-100-malicious-code-execution-models for potential attacks. Any thoughts on this security risk and how we can mitigate? One super easy win seems to just be to have an allowlist of checkpoints we've vetted, and crash `tune download` on any other checkpoints by default?
- Extensibility to sharded checkpoints. Details on this point are in the CR comments.
@@ -56,6 +56,9 @@ jobs:
mkdir -p /tmp/test-artifacts
aws s3 cp s3://pytorch-multimodal/llama2-7b/tokenizer.model /tmp/test-artifacts
aws s3 cp s3://pytorch-multimodal/llama2-7b-01242024 /tmp/test-artifacts
aws s3 cp s3://pytorch-multimodal/small-ckpt-hf-03082024.pt /tmp/test-artifacts
aws s3 cp s3://pytorch-multimodal/small-ckpt-tune-03082024.pt /tmp/test-artifacts
aws s3 cp s3://pytorch-multimodal/small-ckpt-meta-03082024.pt /tmp/test-artifacts
This is going to get pretty annoying to hardcode this tempfile directory, as users on the same box can overwrite each other's stuff / not have access to this directory. We should at least add some sort of unique id to this.
parser.add_argument(
    "--train-type",
    type=str,
    help="Type of finetuning. Currently Full-Finetuning and LoRA have slightly different formats. "
Hmm I'd like to discuss and understand more deeply around this. User having to specify the train type into checkpoint conversion is quite unintuitive, and ideally checkpoint conversion shouldn't have to know about whether the model is going to be used for a full or LoRA finetune at all - it should just produce a checkpoint format that any torchtune training recipe (at least the ones that we write and endorse) can consume.
It also presents additional overhead to the user - I have to run separate conversion scripts to use different finetune techniques - I know currently it's a one time overhead, but worth pointing out + overhead will scale as we introduce more models.
I'm not able to tell from the PR description the blocker in just enabling a consistent format for both finetune techniques - mind elaborating? thanks!
This is a temp hack till I fix LoRA. This PR just makes the change for full finetune. I'll follow up shortly with a PR that fixes this for LoRA and will remove this. I'll add this to the PR description
@@ -33,6 +34,7 @@ def convert_checkpoint(
checkpoint_path (Path): Path to the checkpoint path.
model (str): Model name
output_path (Optional[Path]): Path to the output checkpoint.
train_type (str): Type of finetuning
Add "must be full or lora"?
checkpointer:
  _component_: torchtune.utils.FullModelTorchTuneCheckpointer
  checkpoint_dir: /tmp/llama2/
  checkpoint_files: [llama2_native.pt]
Is this the file we will write out? what's the plan for scaling this when we need to produce checkpoint shards (at least for intermediate checkpoints)?
This is the file we read from. The output file will be written to `output_dir` and matches the input format. See the checkpointer docstring.
_component_: torchtune.utils.FullModelTorchTuneCheckpointer
checkpoint_dir: /tmp/llama2/
checkpoint_files: [llama2_native.pt]
model_type: LLAMA2
where can I, as a user, find the supported "model_type"'s?
I'll add more info on this after the LoRA change. For now it's a copy-paste from the config
def load_checkpoint(self) -> Dict[str, Any]:
    """
    Load TorchTune checkpoint from file. Currently only loading from a single file is supported.
Do we mean load a meta checkpoint?
intermediate_checkpoint: bool = False,
) -> None:
    """
    Save TorchTune checkpoint to file. If ``intermediate_checkpoint`` is True, an additional
"Save torchtune checkpoint to file, in meta format"?
torch.save(state_dict, Path.joinpath(self._output_dir, "recipe_state.pt"))


class FullModelMetaCheckpointer(_CheckpointerInterface):
This is great but I'm curious about the utility of the `save` portion of this file. Are meta checkpoints really used for any off ramps at the moment? They might be and I'm just unaware - do you have a list of off ramps that directly consume the meta style checkpoints?
Llama.cpp for example was built with the Meta format in mind and then extended to handle HF checkpoints. Generally it's a bad idea to not support the original model format. For future llama versions we'll just launch with this format
state_dict.update(recipe_state)
return state_dict


def save_checkpoint(
Another comment about saving in meta format. Do we anticipate use cases where user wants to load meta checkpoint --> train in torchtune --> output HF checkpoint? This seems decently natural if HF checkpoint format supports a lot of off ramps, should we support this, or is there a way user can work around to implement this with this work?
I definitely don't want to include the cross conversion complexity here. We can think about providing some conversion scripts for this scenario
TORCHTUNE_RESTART = "torchtune_restart"


class ModelType(Enum):
I don't fully follow why LoRA should have a different checkpoint, any details we can get here?
Thanks for the detailed reviews @ebsmothers and @rohan-varma.
This is just a hack since I don't have LoRA addressed in this PR. I'll work on that later today and remove this
Not for the MVP. We'll assume we always merge the weights and write a full checkpoint. Once we add adapter checkpointing we'll need to alter save and load and I'd rather have this as a separate class than to just add complexity to the full checkpointer
This needs a bit more testing before I can remove this. Seems like no harm to leave it there for now?
It's unclear. We can restrict it, but it's an arbitrary restriction and is annoying. Currently we restrict it and I just go uncomment this. But not sure what the right way to handle this is other than just letting Hub handle this. We can also support safetensors, but I don't want to take the transformers dependency.
Responded
Updates since last review:
Since the last review I've updated our checkpointing stack. The changes include:
- Added `FullModelTorchTuneCheckpointer`, `FullModelMetaCheckpointer` and `FullModelHFCheckpointer`, which handle all of the logic associated with Meta, HF and TorchTune checkpoints. This drastically simplifies the recipe UX (look at `full_finetune.py`).
- We no longer require `model` as a key even for final checkpoints. That was unintuitive and different from how any other lib and framework does this.

PR Context
Note: A couple of notes as you read through this RFC:
Context
Building interoperability with the surrounding ecosystem is critical for TorchTune. This "off-ramp" i.e. the ease with which users can use the fine-tuned checkpoints with their favorite tools, is as important as the fine-tuning capabilities provided by the library. It's not an exaggeration to say that without having a strong interop story, it'll be hard for TorchTune to gain traction within the community.
Understanding the Landscape
Before we go deeper into building interoperability with the ecosystem, let's take a quick look at the current ecosystem.
HF Model Hub is the de-facto source for most (if not all) popular LLMs. Each model is associated with a checkpoint format. The checkpoint format is different from inference-time formats like GGUF; these refer to the model state dict and how it's presented to model users. At a high level, checkpoint formats can be divided into two popular buckets:
- Meta format: the original checkpoint published as a `.pth` file through the meta-llama repository on HF Model Hub. Various tools like llama.cpp directly build on top of this checkpoint by assuming the keys have a certain format.
- HF format: the checkpoint split across multiple `.bin` files (or stored as safetensors) with an associated index.json file which provides information for building these state dicts back up.

Popular code bases like gpt-fast [script], GPTQ-for-Llama [script], llama.cpp [script] etc. all depend on the above formats or provide the option to write custom converters.
Given the above state, my claim is that we should build TorchTune to be "state-dict invariant", i.e. convert checkpoints from popular formats into TorchTune's format -> train -> convert back to the original format. The rest of this RFC goes over this idea. But a few FAQs before that:
The TorchTune modeling components and classes are built with modularity and flexibility in mind. Using Transformers negates this design principle and takes away "extensibility" as one of our core value propositions to users. It also negates our goal of being "native PyTorch" since these frameworks and libraries have a strict structure which needs to be followed. gpt-fast has a similar structure where the code base first [converts the ckpt].
The above applies to both Mistral [HF Repo] and Gemma [HF Repo].
What does "be state-dict invariant" mean for TorchTune?
This has a sizable impact on our current user experience, but the ROI is high since not only do we get a "built-in off-ramp", but adding new models becomes easier.
Our current flow looks something like this: download the checkpoint (eg: from the `meta-llama/Llama-2-7b` repo), convert it into TorchTune's format with the cli conversion tool, and then fine-tune, with checkpoints written out in TorchTune's format.

The above flow means that for inference, we need to first convert the final checkpoint into a standard format (eg: GGUF or an Executorch-friendly-format) by writing a custom convertor, which can be substantial work [example]. Alternatively, we need to adopt a standard implementation of popular models which is also a no-go as mentioned above. As a result, adding new model support will be slow since we will need to build a new off-ramp for each model implementation.
The flow proposed by this RFC looks something like this: read the original checkpoint directly, convert it into TorchTune's format within the recipe at load time, fine-tune, and convert the final checkpoint back to the original format at save time. To minimize cognitive load on the users, TorchTune recipes will handle conversions "to" and "from" the above formats instead of delegating this to the cli tool.
Concretely, the user would run the following:
With the above flow, I'm directly able to convert the model into GGUF to run inference and quantization using llama.cpp as well as use gpt-fast for running generation and quantization.
What changes in code?
The user no longer needs to know about TorchTune's checkpoint format unless they're resuming training from a previously failed run. Specifically, we:
- Add support for `META_FORMAT` for `llama-2-7b`, which includes authoring the `convert_llama2_from_meta_format` and `convert_llama2_to_meta_format` functions for translating state dicts.
- Expose `CheckpointFormat` through the recipe and config and explicitly ask the user to specify the checkpoint format (detailed documentation will make this clear).
- Update `utils.load_checkpoint` to extract the state_dict from the original checkpoint rather than the translated checkpoint. This means adding support for both checkpoint files and checkpoint directories (checkpoints can be split across multiple files), which is done through the `_fetch_meta_format_state_dict` function which in turn is hooked up to `utils.load_checkpoint` through the `load_external_checkpoint` function.
- Update `load_checkpoint` in the recipe to translate the state dict from the original format to TorchTune's format. Behavior for resuming training remains the same (see below) when `resume_from_checkpoint` is `True`; we need the additional information in the state dict.
- Update `save_checkpoint` in the recipe to translate the final state dict back to the original format.

How does this make things better?
We get two big advantages.
Adding support for a new format is straightforward

For adding the `HF_FORMAT` and opening up the repo to a large number of llama-7b models trained using `transformers`, we need to only:
- Author the `convert_llama2_from_hf_format` and `convert_llama2_to_hf_format` functions for translating state dicts (see the sketch below).
- Write `_fetch_hf_format_state_dict` which contains logic for extracting state dicts from multiple `.bin` files.
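A rough sketch of what such a convert function pair could look like. The function names come from this RFC; the HF-side keys follow the standard LlamaForCausalLM naming, while the torchtune-side keys and the partial mapping are illustrative placeholders rather than the actual implementation:

```python
import re
from typing import Any, Dict

# Partial, illustrative key mapping from HF llama2 names to torchtune names.
_FROM_HF = {
    "model.embed_tokens.weight": "tok_embeddings.weight",
    "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attn.q_proj.weight",
    "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attn.k_proj.weight",
    "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attn.v_proj.weight",
    "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attn.output_proj.weight",
    "model.norm.weight": "norm.scale",
    "lm_head.weight": "output.weight",
}


def _map_key(key: str, mapping: Dict[str, str]) -> str:
    # Swap the layer index for "{}" to look up the template, then put it back.
    match = re.search(r"\.(\d+)\.", key)
    if match is None:
        return mapping[key]
    template = key.replace(f".{match.group(1)}.", ".{}.", 1)
    return mapping[template].format(match.group(1))


def convert_llama2_from_hf_format(state_dict: Dict[str, Any]) -> Dict[str, Any]:
    return {_map_key(k, _FROM_HF): v for k, v in state_dict.items()}


def convert_llama2_to_hf_format(state_dict: Dict[str, Any]) -> Dict[str, Any]:
    inverted = {v: k for k, v in _FROM_HF.items()}
    return {_map_key(k, inverted): v for k, v in state_dict.items()}
```

Since both directions share the same mapping dict, supporting another model that reuses the llama architecture is mostly a matter of reusing or lightly extending this dict, which is the point made earlier about Mistral.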
With the `HF_FORMAT` added, support for llama2 13B and 30B models (at least on an 80GB A100) should be straightforward as well. Adding Mistral-7B and Gemma-7B should not be as much work as before either.

Running your favorite inference tools is straightforward
The following conversion to GGUF using llama.cpp works OOTB
where `finetuned_llama` contains the output checkpoints from a full-finetuning run in TorchTune.

Other FAQs
For the MVP, we plan to provide "merged weights" (inference-time weight merging and downstream support is out-of-scope). This should be a straight-forward swap since the state dict keys remain the same.
Great question! A couple of reasons for this:
For now, take my word that it works! I'm working on updating the tests, adding detailed doc strings and even a tutorial on the overall flow. But before I put in all of that work, I'd like some initial feedback.