
Instantiate model from automodel #601

Merged

Conversation

svenhendrikx
This pull request implements #521, adding functionality to allow users to run lm-eval tasks directly on transformers.PreTrainedModel instances using simple_evaluate. It contains three commits:

  1. Add logic to the HFLM class, such that a transformers.PreTrainedModel instance can be passed as the pretrained argument, as @haileyschoelkopf suggested here: Add a way to instantiate from HF.AutoModel #521

  2. Add an `__init__.py` file to the bigbench_resources directory, so that it is included in the build. Without it, importing the tasks fails when the package is installed via pip from the GitHub link.

  3. Add logic to simple_evaluate, so that you can pass it a transformers.PreTrainedModel instance as well. I chose to instantiate the object directly, whereas when a string is passed, the models.get_model function is used. Direct instantiation seemed like the simplest solution, but the functionality could also be added to the get_model function. Let me know what you think.
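The dispatch described in point 3 might look roughly like this. This is a hypothetical simplification, not the PR's actual diff; `PreTrainedModel` below is a stand-in class so the sketch runs without transformers installed, and `resolve_lm` is an illustrative name, not a real lm-eval function:

```python
class PreTrainedModel:
    """Stand-in for transformers.PreTrainedModel (assumption for this sketch)."""


def resolve_lm(model):
    """Return which construction path simple_evaluate would take."""
    if isinstance(model, str):
        # Registry path: lm_eval.models.get_model(model) looks the wrapper
        # class up by name and builds it from the model_args string.
        return "get_model"
    elif isinstance(model, PreTrainedModel):
        # Instance path: wrap the already-loaded model in HFLM directly.
        return "HFLM"
    else:
        raise TypeError(
            f"model must be a str or PreTrainedModel, got {type(model).__name__}"
        )
```

Either branch ends with an `LM` object the rest of `simple_evaluate` can use unchanged.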

This is my first contribution to lm-eval, so feel free to share tips.

@CLAassistant

CLAassistant commented Jun 18, 2023

CLA assistant check
All committers have signed the CLA.

@@ -60,8 +60,8 @@ def __init__(
trust_remote_code=trust_remote_code,
)


else:
elif isinstance(pretrained, str):
Author

Something I realized last minute: this is a bit cleaner than using an assertion. The else block now raises a TypeError, which is more descriptive.
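A hypothetical side-by-side of the two styles being compared here (not the PR's actual code):

```python
def load_with_assert(pretrained):
    # Assertion style: failure is a terse AssertionError, and asserts
    # are stripped entirely when Python runs with the -O flag.
    assert isinstance(pretrained, str), "pretrained must be a string"
    return pretrained


def load_with_elif(pretrained):
    # elif/else style: unknown types fall through to a TypeError
    # that names the offending type.
    if isinstance(pretrained, str):
        return pretrained
    else:
        raise TypeError(
            "pretrained must be a str or PreTrainedModel instance, "
            f"got {type(pretrained).__name__}"
        )
```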

198,
198,
31373,
], self.tokenizer.encode("hello\n\nhello")
Contributor

This tokenizer assert is no longer required! It's a holdover from earlier commits, where this model type was assumed to be gpt2 if using a GPT2Tokenizer type.



# Initialize model
if isinstance(pretrained, transformers.PreTrainedModel):
Contributor

minor nit: it'd be nice if we could confirm this is of type AutoModelForCausalLM or related subclasses, since this LM subclass only assumes a causal decoder-only model type.
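One wrinkle with this nit: `AutoModelForCausalLM` is a factory rather than a base class, so a plain `isinstance()` check against it does not work. A lightweight alternative is to inspect the model config. The sketch below is a hypothetical illustration, not the harness's code; `DummyConfig`/`DummyModel` are stand-ins so it runs standalone, though `is_encoder_decoder` is a real field on transformers' `PretrainedConfig`:

```python
class DummyConfig:
    """Stand-in for transformers.PretrainedConfig."""
    def __init__(self, is_encoder_decoder=False):
        self.is_encoder_decoder = is_encoder_decoder


class DummyModel:
    """Stand-in for a transformers.PreTrainedModel instance."""
    def __init__(self, is_encoder_decoder=False):
        self.config = DummyConfig(is_encoder_decoder)


def check_decoder_only(model):
    # HFLM assumes a causal decoder-only model, so reject seq2seq
    # architectures (T5, BART, ...) early with a clear error.
    if getattr(model.config, "is_encoder_decoder", False):
        raise TypeError("HFLM expects a decoder-only causal LM")
    return True
```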

@@ -72,6 +76,11 @@ def simple_evaluate(
lm = lm_eval.models.get_model(model).create_from_arg_string(
model_args, {"batch_size": batch_size, "max_batch_size": max_batch_size, "device": device}
)
elif isinstance(model, transformers.PreTrainedModel):
lm = HFLM(
pretrained=model,
Contributor

We also want to pass batch_size=batch_size to this I believe. Agree that we should assume the user has already placed their model onto the correct device though!

Contributor

Preferably use lm = lm_eval.models.get_model("hf-causal") here too, instead of instantiating HFLM directly.
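A minimal sketch of the registry pattern behind `lm_eval.models.get_model` (a hypothetical simplification, not the harness's actual implementation). Routing through `get_model("hf-causal")` keeps `simple_evaluate` free of hard-coded class references, and the `batch_size` argument from the review above can be forwarded as well:

```python
# Hypothetical name->class registry; the real lm-eval registry differs.
MODEL_REGISTRY = {}


def register_model(name):
    def decorator(cls):
        MODEL_REGISTRY[name] = cls
        return cls
    return decorator


def get_model(name):
    return MODEL_REGISTRY[name]


@register_model("hf-causal")
class HFLM:
    """Stand-in for the real lm-eval HFLM wrapper."""
    def __init__(self, pretrained, batch_size=1):
        self.pretrained = pretrained
        self.batch_size = batch_size
```

With this shape, the two review suggestions combine into `lm = get_model("hf-causal")(pretrained=model, batch_size=batch_size)`.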

@haileyschoelkopf
Contributor

haileyschoelkopf commented Jun 27, 2023

Thanks so much for this PR, and I apologize for the slow review! It looks great, but I left a couple of minor nitpicks. Happy to return to these later today to fix and merge. EDIT: Testing this now.

@haileyschoelkopf
Contributor

PRed changes to your branch here: svenhendrikx#1. Once these are merged, LGTM!

@haileyschoelkopf
Contributor

Thanks again for the contribution!!

@haileyschoelkopf haileyschoelkopf merged commit 72b7f0c into EleutherAI:master Jun 27, 2023
2 checks passed
qmdnls pushed a commit to qmdnls/lm-evaluation-harness that referenced this pull request Aug 17, 2023
…-from-Automodel

Instantiate model from automodel