
Commit

Updated Readme and references
adrianjav committed Feb 13, 2022
1 parent bbc104d commit 03372df
Showing 3 changed files with 17 additions and 9 deletions.
20 changes: 14 additions & 6 deletions README.md
@@ -33,7 +33,7 @@ Then you can simply use RotateOnly, RotoGrad, or RotoGradNorm (RotateOnly + GradNorm)

```python
from rotograd import RotoGrad
-model = RotoGrad(backbone, [head1, head2], size_z, alpha=1.)
+model = RotoGrad(backbone, [head1, head2], size_z, normalize_losses=True)
```

where you can recover the backbone and i-th head simply calling `model.backbone` and `model.heads[i]`. Even
@@ -71,15 +71,23 @@ def step(x, y1, y2):
return loss1, loss2
```
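RotoGrad's key idea is to rotate each task's gradient in the shared representation space so that the tasks agree on a common update direction, while each gradient keeps its norm. A toy 2-D illustration of that alignment step in pure Python (a conceptual sketch, not the library's actual implementation, which learns the rotations during training):

```python
import math

def rotate(v, theta):
    """Rotate a 2-D vector v counter-clockwise by angle theta."""
    c, s = math.cos(theta), math.sin(theta)
    return (c * v[0] - s * v[1], s * v[0] + c * v[1])

def angle(v):
    return math.atan2(v[1], v[0])

def cosine(u, v):
    dot = u[0] * v[0] + u[1] * v[1]
    return dot / (math.hypot(*u) * math.hypot(*v))

# Two conflicting task gradients in the shared feature space.
g1, g2 = (1.0, 0.2), (-0.3, 1.0)

# Target direction: the (unnormalized) sum of the task gradients.
target = (g1[0] + g2[0], g1[1] + g2[1])

# Rotate each gradient onto the target direction; rotations preserve norms.
r1 = rotate(g1, angle(target) - angle(g1))
r2 = rotate(g2, angle(target) - angle(g2))

print(round(cosine(r1, r2), 6))  # prints 1.0: the rotated gradients are aligned
```

In the library itself the rotation matrices are applied to the shared representation rather than to raw 2-D gradients, but the geometric intuition is the same.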

## Example

You can find a working example in the folder `example`. However, it requires some additional dependencies to run (e.g.,
ignite and seaborn). The example shows how to use RotoGrad on one of the regression problems from the manuscript.

![image](_assets/toy.gif)

## Citing

Consider citing the following paper if you use RotoGrad:

```bibtex
-@article{javaloy2021rotograd,
-  title={RotoGrad: Gradient Homogenization in Multitask Learning},
-  author={Javaloy, Adri\'an and Valera, Isabel},
-  journal={arXiv preprint arXiv:2103.02631},
-  year={2021}
+@inproceedings{javaloy2022rotograd,
+  title={RotoGrad: Gradient Homogenization in Multitask Learning},
+  author={Adri{\'a}n Javaloy and Isabel Valera},
+  booktitle={International Conference on Learning Representations},
+  year={2022},
+  url={https://openreview.net/forum?id=T8wHz4rnuGL}
}
```
Binary file added _assets/toy.gif
6 changes: 3 additions & 3 deletions rotograd/rotograd.py
@@ -181,7 +181,7 @@ class RotateOnly(nn.Module):
References
----------
.. [1] Javaloy, Adrián, and Isabel Valera. "RotoGrad: Gradient Homogenization in Multitask Learning."
-       arXiv preprint arXiv:2103.02631 (2021).
+       International Conference on Learning Representations (2022).
"""
num_tasks: int
@@ -408,7 +408,7 @@ class RotoGrad(RotateOnly):
References
----------
.. [1] Javaloy, Adrián, and Isabel Valera. "RotoGrad: Gradient Homogenization in Multitask Learning."
-       arXiv preprint arXiv:2103.02631 (2021).
+       International Conference on Learning Representations (2022).
"""
num_tasks: int
@@ -479,7 +479,7 @@ class RotoGradNorm(RotoGrad):
References
----------
.. [1] Javaloy, Adrián, and Isabel Valera. "RotoGrad: Gradient Homogenization in Multitask Learning."
-       arXiv preprint arXiv:2103.02631 (2021).
+       International Conference on Learning Representations (2022).
.. [2] Chen, Zhao, et al. "Gradnorm: Gradient normalization for adaptive loss balancing in deep multitask networks."
International Conference on Machine Learning. PMLR, 2018.
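The GradNorm reference [2] motivates the scaling half of RotoGradNorm: besides rotating the gradients, it rescales each task's contribution so that no single task dominates the shared backbone. A toy, closed-form sketch of such magnitude equalization in pure Python (illustrative only; GradNorm itself *learns* the task weights via an auxiliary loss rather than setting them directly):

```python
import math

def equalize(grads):
    """Scale each 2-D gradient so all of them share the mean norm.

    A simplified stand-in for GradNorm-style balancing: the direction of
    each gradient is kept, only its magnitude is adjusted.
    """
    norms = [math.hypot(*g) for g in grads]
    mean = sum(norms) / len(norms)
    return [tuple(mean / n * x for x in g) for g, n in zip(grads, norms)]

# Two task gradients with very different magnitudes (norms 5.0 and 0.5),
# so the first task would otherwise dominate the shared parameters.
g1, g2 = (3.0, 4.0), (0.3, 0.4)
e1, e2 = equalize([g1, g2])
print(round(math.hypot(*e1), 6), round(math.hypot(*e2), 6))  # 2.75 2.75
```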
