Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

FEAT: add tagging support to axolotl #1004

Merged
merged 3 commits into from
Dec 27, 2023

Conversation

younesbelkada
Copy link
Contributor

Hi there,

I would like to introduce a feature request to axolotl; automatic tagging when users push the trained model on the Hub. That way, you can easily filter on 馃 Hub models that has been trained with axolotl with a simple filter: https://huggingface.co/models?other=axolotl&sort=created

Similar feature as: huggingface/trl#1133 that we added in TRL library

cc @winglian

@younesbelkada
Copy link
Contributor Author

@winglian , shall we both quickly test if the tags work? I suggest we simply init a xxxTrainer and just call trainer.push_to_hub() on a small model

@winglian
Copy link
Collaborator

@winglian , shall we both quickly test if the tags work? I suggest we simply init a xxxTrainer and just call trainer.push_to_hub() on a small model

that would be great if you could test that to verify please.

@younesbelkada
Copy link
Contributor Author

younesbelkada commented Dec 27, 2023

Thanks @winglian
I just tested it out for all axolotl trainer classes using this script for completeness using the latest commit from this branch:

from transformers import TrainingArguments, AutoModelForCausalLM

from axolotl.core.trainer_builder import AxolotlTrainer, AxolotlMambaTrainer, OneCycleLRSchedulerTrainer, ReLoRATrainer

trainer_classes = [AxolotlTrainer, AxolotlMambaTrainer, OneCycleLRSchedulerTrainer, ReLoRATrainer]

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

for trainer_cls in trainer_classes:
    args = TrainingArguments(
        output_dir=f"test-axolotl-{trainer_cls.__name__.lower()}"
    )

    trainer = trainer_cls(
        args=args,
        model=model
    )

    trainer.push_to_hub()

And you can find below, the links to all pushed trainers:

As you can see, they all have the tag axolotl, with the last 3 trainers having each an extra tag (e.g. mamba, relora, etc.)

@winglian winglian merged commit db9094d into OpenAccess-AI-Collective:main Dec 27, 2023
4 checks passed
@tmm1
Copy link
Collaborator

tmm1 commented Dec 27, 2023

This is really great, thanks @younesbelkada !

@younesbelkada younesbelkada deleted the add-axolotl-tag branch December 27, 2023 23:12
@younesbelkada
Copy link
Contributor Author

Thanks @winglian @tmm1 !

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants