Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add align_labels_with_mapping function #2457

Merged
merged 6 commits into from
Jun 17, 2021

Conversation

lewtun
Copy link
Member

@lewtun lewtun commented Jun 8, 2021

This PR adds a helper function to align the label2id mapping between a datasets.Dataset and a classifier (e.g. a transformer with a PretrainedConfig.label2id dict), with the alignment performed on the dataset itself.

This will help us with the Hub evaluation, where we won't know in advance whether a model that is fine-tuned on say MNLI has the same mappings as the MNLI dataset we load from datasets.

An example where this is needed is if we naively try to evaluate microsoft/deberta-base-mnli on mnli because the model config has the following mappings:

  "id2label": {
    "0": "CONTRADICTION",
    "1": "NEUTRAL",
    "2": "ENTAILMENT"
  },
  "label2id": {
    "CONTRADICTION": 0,
    "ENTAILMENT": 2,
    "NEUTRAL": 1
  }

while the mnli dataset has the contradiction and neutral labels swapped:

id2label = {0: 'entailment', 1: 'neutral', 2: 'contradiction'}
label2id = {'contradiction': 2, 'entailment': 0, 'neutral': 1}

As a result, we get a much lower accuracy during evaluation:

from datasets import load_dataset
from transformers.trainer_utils import EvalPrediction
from transformers import AutoModelForSequenceClassification, Trainer

# load dataset for evaluation
mnli = load_dataset("glue", "mnli", split="test")
# load model
model_ckpt = "microsoft/deberta-base-mnli"
model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
# preprocess, create trainer ...
mnli_enc = ...
trainer = Trainer(model, args=args, tokenizer=tokenizer)
# generate preds
preds = trainer.predict(mnli_enc)
# preds.label_ids misalinged with model.config => returns wrong accuracy (too low)!
compute_metrics(EvalPrediction(preds.predictions, preds.label_ids))

The fix is to use the helper function before running the evaluation to make sure the label IDs are aligned:

mnli_enc_aligned = mnli_enc.align_labels_with_mapping(label2id=config.label2id, label_column="label")
# preds now aligned and everyone is happy :)
preds = trainer.predict(mnli_enc_aligned)

cc @thomwolf @lhoestq

@lewtun lewtun marked this pull request as ready for review June 10, 2021 09:31
@lewtun lewtun changed the title [WIP] Add align_labels_with_mapping function Add align_labels_with_mapping function Jun 10, 2021
@lewtun lewtun requested a review from lhoestq June 10, 2021 09:31
@lewtun
Copy link
Member Author

lewtun commented Jun 10, 2021

@lhoestq i think this is ready for another review 🙂

Copy link
Member

@lhoestq lhoestq left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding this method and writing the test !
Just have 2 comments:

src/datasets/arrow_dataset.py Outdated Show resolved Hide resolved
src/datasets/arrow_dataset.py Show resolved Hide resolved
@lewtun
Copy link
Member Author

lewtun commented Jun 17, 2021

@lhoestq thanks for the feedback - it's now integrated :)

i also added a comment about sorting the input label IDs

Copy link
Member

@lhoestq lhoestq left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks ! Looks all good now :)

We will also need to have the DatasetDict.align_labels_with_mapping method. Let me quickly add it

src/datasets/arrow_dataset.py Outdated Show resolved Hide resolved
@lhoestq lhoestq merged commit 28db79d into huggingface:master Jun 17, 2021
@lhoestq
Copy link
Member

lhoestq commented Jun 17, 2021

Created the PR here: #2510

@lewtun
Copy link
Member Author

lewtun commented Jun 17, 2021

Thanks ! Looks all good now :)

We will also need to have the DatasetDict.align_labels_with_mapping method. Let me quickly add it

thanks a lot! i always forget about DatasetDict - will be happy when it's just one "dataset" object :)

@lewtun lewtun deleted the align-labels-with-mapping branch June 17, 2021 10:17
JayantGoel001 added a commit to JayantGoel001/datasets-1 that referenced this pull request Jun 17, 2021
Add align_labels_with_mapping function (huggingface#2457)
@retoj
Copy link

retoj commented Jan 12, 2022

So, there seems to be a problem with the function align_labels_with_mapping for models like this: https://huggingface.co/huggingface/distilbert-base-uncased-finetuned-mnli]. At least with this model, but perhaps also with others, the model.config.label2id values are of type str not int, which crashes said function. After manually converting the model.config.label2id values to int, the script runs smoothly.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants