Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Updated Loss metric to use required_output_keys #2027

Merged
merged 5 commits into from
Jun 3, 2021

Conversation

01-vyom
Copy link
Contributor

@01-vyom 01-vyom commented Jun 2, 2021

Fixes #1415

Description:
Similar to Metric, Loss metric now supports required_output_keys.

Check list:

  • New tests are added (if a new feature is added)
  • New doc strings: description and/or example code are in RST format
  • Documentation is updated (if required)

@github-actions github-actions bot added the module: metrics Metrics module label Jun 2, 2021
@01-vyom
Copy link
Contributor Author

01-vyom commented Jun 2, 2021

Following is the demo code to check the validity of the above change:

import torch
import torch.nn as nn
from torch.nn.functional import nll_loss

from ignite.metrics import Accuracy, Loss
from ignite.engine import create_supervised_evaluator

model = nn.Linear(10, 3)

metrics = {
    "Accuracy": Accuracy(),
    "Loss": Loss(nll_loss)
}

# global criterion kwargs
criterion_kwargs = {"reduction": 'sum'}
# criterion_kwargs = {}

evaluator = create_supervised_evaluator(
    model,
    metrics=metrics,
    output_transform=lambda x, y, y_pred: {
        "x": x, "y": y, "y_pred": y_pred, "criterion_kwargs": criterion_kwargs}
)
data = [
    (torch.rand(4, 10), torch.randint(0, 3, size=(4, ))),
    (torch.rand(4, 10), torch.randint(0, 3, size=(4, ))),
    (torch.rand(4, 10), torch.randint(0, 3, size=(4, )))
]
res = evaluator.run(data)

As the required_output_keys contains criterion_kwargs, the user has to pass an empty dictionary for the original case with no criterion_kwargs for criterion.

@vfdev-5
Copy link
Collaborator

vfdev-5 commented Jun 2, 2021

@01-vyom thanks for the draft PR !
Let's now add some tests here:

Instead of DummyMetric1, we can create a metric derived from Loss and override update method to check passed args similarly to here:

def update(self, output):
assert output == self.true_output

Please let me know if need more explanations.

We can also add your above demo example as an integration test to https://github.com/pytorch/ignite/blob/master/tests/ignite/metrics/test_loss.py

@01-vyom
Copy link
Contributor Author

01-vyom commented Jun 3, 2021

Ok, I will add a test_output_mapping as well as all other none_keys and wrong_keys test with a dummy Loss and an integration test similar to here:

def test_override_required_output_keys():

@01-vyom 01-vyom changed the title [WIP] ENH: Updated Loss metric to use required_output_keys ENH: Updated Loss metric to use required_output_keys Jun 3, 2021
ignite/metrics/loss.py Outdated Show resolved Hide resolved
ignite/metrics/loss.py Outdated Show resolved Hide resolved
ignite/metrics/loss.py Outdated Show resolved Hide resolved
@01-vyom 01-vyom requested a review from vfdev-5 June 3, 2021 19:39
Copy link
Collaborator

@vfdev-5 vfdev-5 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR @01-vyom !

@vfdev-5 vfdev-5 enabled auto-merge (squash) June 3, 2021 20:11
@vfdev-5 vfdev-5 merged commit 786aea8 into pytorch:master Jun 3, 2021
@01-vyom 01-vyom deleted the loss-metric-1415 branch June 3, 2021 20:54
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
module: metrics Metrics module
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Loss metric to use required_output_keys
3 participants