diff --git a/README.md b/README.md index 0217cbbd..f60329cd 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ **pytorch-optimizer** is optimizer & lr scheduler collections in PyTorch. I just re-implemented (speed & memory tweaks, plug-ins) the algorithm while based on the original paper. Also, It includes useful and practical optimization ideas. -Currently, **68 optimizers (+ `bitsandbytes`)**, **11 lr schedulers**, and **13 loss functions** are supported! +Currently, **69 optimizers (+ `bitsandbytes`)**, **11 lr schedulers**, and **13 loss functions** are supported! Highly inspired by [pytorch-optimizer](https://github.com/jettify/pytorch-optimizer). @@ -165,6 +165,7 @@ supported_optimizers = get_supported_optimizers() | bSAM | *SAM as an Optimal Relaxation of Bayes* | [github](https://github.com/team-approx-bayes/bayesian-sam) | | [cite](https://ui.adsabs.harvard.edu/abs/2022arXiv221001620M/exportcitation) | | Schedule-Free | *Schedule-Free Optimizers* | [github](https://github.com/facebookresearch/schedule_free) | | [cite](https://github.com/facebookresearch/schedule_free) | | FAdam | *Adam is a natural gradient optimizer using diagonal empirical Fisher information* | [github](https://github.com/lessw2020/fadam_pytorch) | | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv240512807H/exportcitation) | +| Grokfast | *Accelerated Grokking by Amplifying Slow Gradients* | [github](https://github.com/ironjr/grokfast) | | [cite](https://github.com/ironjr/grokfast?tab=readme-ov-file#citation) | ## Supported LR Scheduler @@ -325,7 +326,7 @@ If you use this software, please cite it below. Or you can get it from "cite thi month = jan, title = {{pytorch_optimizer: optimizer & lr scheduler & loss function collections in PyTorch}}, url = {https://github.com/kozistr/pytorch_optimizer}, - version = {2.12.0}, + version = {3.0.1}, year = {2021} } diff --git a/docs/changelogs/v3.0.1.md b/docs/changelogs/v3.0.1.md index 032f193a..969ef1bd 100644 --- a/docs/changelogs/v3.0.1.md +++ b/docs/changelogs/v3.0.1.md @@ -8,6 +8,8 @@ * support not-using-first-momentum when beta1 is not given * default dtype for first momentum to `bfloat16` * clip second momentum to 0.999 +* Implement `GrokFast` optimizer. (#244, #245) + * [Accelerated Grokking by Amplifying Slow Gradients](https://arxiv.org/abs/2405.20233) ### Bug diff --git a/docs/index.md b/docs/index.md index 0217cbbd..f60329cd 100644 --- a/docs/index.md +++ b/docs/index.md @@ -10,7 +10,7 @@ **pytorch-optimizer** is optimizer & lr scheduler collections in PyTorch. I just re-implemented (speed & memory tweaks, plug-ins) the algorithm while based on the original paper. Also, It includes useful and practical optimization ideas. -Currently, **68 optimizers (+ `bitsandbytes`)**, **11 lr schedulers**, and **13 loss functions** are supported! +Currently, **69 optimizers (+ `bitsandbytes`)**, **11 lr schedulers**, and **13 loss functions** are supported! Highly inspired by [pytorch-optimizer](https://github.com/jettify/pytorch-optimizer). @@ -165,6 +165,7 @@ supported_optimizers = get_supported_optimizers() | bSAM | *SAM as an Optimal Relaxation of Bayes* | [github](https://github.com/team-approx-bayes/bayesian-sam) | | [cite](https://ui.adsabs.harvard.edu/abs/2022arXiv221001620M/exportcitation) | | Schedule-Free | *Schedule-Free Optimizers* | [github](https://github.com/facebookresearch/schedule_free) | | [cite](https://github.com/facebookresearch/schedule_free) | | FAdam | *Adam is a natural gradient optimizer using diagonal empirical Fisher information* | [github](https://github.com/lessw2020/fadam_pytorch) | | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv240512807H/exportcitation) | +| Grokfast | *Accelerated Grokking by Amplifying Slow Gradients* | [github](https://github.com/ironjr/grokfast) | | [cite](https://github.com/ironjr/grokfast?tab=readme-ov-file#citation) | ## Supported LR Scheduler @@ -325,7 +326,7 @@ If you use this software, please cite it below. Or you can get it from "cite thi month = jan, title = {{pytorch_optimizer: optimizer & lr scheduler & loss function collections in PyTorch}}, url = {https://github.com/kozistr/pytorch_optimizer}, - version = {2.12.0}, + version = {3.0.1}, year = {2021} } diff --git a/docs/optimizer.md b/docs/optimizer.md index 523fc1a6..94f94f93 100644 --- a/docs/optimizer.md +++ b/docs/optimizer.md @@ -156,6 +156,18 @@ :docstring: :members: +::: pytorch_optimizer.gradfilter_ema + :docstring: + :members: + +::: pytorch_optimizer.gradfilter_ma + :docstring: + :members: + +::: pytorch_optimizer.GrokFastAdamW + :docstring: + :members: + ::: pytorch_optimizer.GSAM :docstring: :members: diff --git a/poetry.lock b/poetry.lock index 3265a114..3e213bca 100644 --- a/poetry.lock +++ b/poetry.lock @@ -173,29 +173,29 @@ test = ["pytest (>=6)"] [[package]] name = "filelock" -version = "3.14.0" +version = "3.15.3" description = "A platform independent file lock." optional = false python-versions = ">=3.8" files = [ - {file = "filelock-3.14.0-py3-none-any.whl", hash = "sha256:43339835842f110ca7ae60f1e1c160714c5a6afd15a2873419ab185334975c0f"}, - {file = "filelock-3.14.0.tar.gz", hash = "sha256:6ea72da3be9b8c82afd3edcf99f2fffbb5076335a5ae4d03248bb5b6c3eae78a"}, + {file = "filelock-3.15.3-py3-none-any.whl", hash = "sha256:0151273e5b5d6cf753a61ec83b3a9b7d8821c39ae9af9d7ecf2f9e2f17404103"}, + {file = "filelock-3.15.3.tar.gz", hash = "sha256:e1199bf5194a2277273dacd50269f0d87d0682088a3c561c15674ea9005d8635"}, ] [package.extras] docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.25.2)"] -testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8.0.1)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)"] +testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8.0.1)", "pytest (>=7.4.3)", "pytest-asyncio (>=0.21)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)", "virtualenv (>=20.26.2)"] typing = ["typing-extensions (>=4.8)"] [[package]] name = "fsspec" -version = "2024.5.0" +version = "2024.6.0" description = "File-system specification" optional = false python-versions = ">=3.8" files = [ - {file = "fsspec-2024.5.0-py3-none-any.whl", hash = "sha256:e0fdbc446d67e182f49a70b82cf7889028a63588fde6b222521f10937b2b670c"}, - {file = "fsspec-2024.5.0.tar.gz", hash = "sha256:1d021b0b0f933e3b3029ed808eb400c08ba101ca2de4b3483fbc9ca23fcee94a"}, + {file = "fsspec-2024.6.0-py3-none-any.whl", hash = "sha256:58d7122eb8a1a46f7f13453187bfea4972d66bf01618d37366521b1998034cee"}, + {file = "fsspec-2024.6.0.tar.gz", hash = "sha256:f579960a56e6d8038a9efc8f9c77279ec12e6299aa86b0769a7e9c46b94527c2"}, ] [package.extras] @@ -204,6 +204,7 @@ adl = ["adlfs"] arrow = ["pyarrow (>=1)"] dask = ["dask", "distributed"] dev = ["pre-commit", "ruff"] +doc = ["numpydoc", "sphinx", "sphinx-design", "sphinx-rtd-theme", "yarl"] dropbox = ["dropbox", "dropboxdrivefs", "requests"] full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "dask", "distributed", "dropbox", "dropboxdrivefs", "fusepy", "gcsfs", "libarchive-c", "ocifs", "panel", "paramiko", "pyarrow (>=1)", "pygit2", "requests", "s3fs", "smbprotocol", "tqdm"] fuse = ["fusepy"] @@ -453,13 +454,13 @@ files = [ [[package]] name = "packaging" -version = "24.0" +version = "24.1" description = "Core utilities for Python packages" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "packaging-24.0-py3-none-any.whl", hash = "sha256:2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5"}, - {file = "packaging-24.0.tar.gz", hash = "sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9"}, + {file = "packaging-24.1-py3-none-any.whl", hash = "sha256:5b8f2217dbdbd2f7f384c41c628544e6d52f2d0f53c6d0c3ea61aa5d1d7ff124"}, + {file = "packaging-24.1.tar.gz", hash = "sha256:026ed72c8ed3fcce5bf8950572258698927fd1dbda10a5e981cdf0ac37f4f002"}, ] [[package]] @@ -506,13 +507,13 @@ testing = ["pytest", "pytest-benchmark"] [[package]] name = "pytest" -version = "8.2.1" +version = "8.2.2" description = "pytest: simple powerful testing with Python" optional = false python-versions = ">=3.8" files = [ - {file = "pytest-8.2.1-py3-none-any.whl", hash = "sha256:faccc5d332b8c3719f40283d0d44aa5cf101cec36f88cde9ed8f2bc0538612b1"}, - {file = "pytest-8.2.1.tar.gz", hash = "sha256:5046e5b46d8e4cac199c373041f26be56fdb81eb4e67dc11d4e10811fc3408fd"}, + {file = "pytest-8.2.2-py3-none-any.whl", hash = "sha256:c434598117762e2bd304e526244f67bf66bbd7b5d6cf22138be51ff661980343"}, + {file = "pytest-8.2.2.tar.gz", hash = "sha256:de4bb8104e201939ccdc688b27a89a7be2079b22e2bd2b07f806b6ba71117977"}, ] [package.dependencies] @@ -546,28 +547,28 @@ testing = ["fields", "hunter", "process-tests", "pytest-xdist", "virtualenv"] [[package]] name = "ruff" -version = "0.4.7" +version = "0.4.10" description = "An extremely fast Python linter and code formatter, written in Rust." optional = false python-versions = ">=3.7" files = [ - {file = "ruff-0.4.7-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:e089371c67892a73b6bb1525608e89a2aca1b77b5440acf7a71dda5dac958f9e"}, - {file = "ruff-0.4.7-py3-none-macosx_11_0_arm64.whl", hash = "sha256:10f973d521d910e5f9c72ab27e409e839089f955be8a4c8826601a6323a89753"}, - {file = "ruff-0.4.7-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:59c3d110970001dfa494bcd95478e62286c751126dfb15c3c46e7915fc49694f"}, - {file = "ruff-0.4.7-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:fa9773c6c00f4958f73b317bc0fd125295110c3776089f6ef318f4b775f0abe4"}, - {file = "ruff-0.4.7-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:07fc80bbb61e42b3b23b10fda6a2a0f5a067f810180a3760c5ef1b456c21b9db"}, - {file = "ruff-0.4.7-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:fa4dafe3fe66d90e2e2b63fa1591dd6e3f090ca2128daa0be33db894e6c18648"}, - {file = "ruff-0.4.7-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a7c0083febdec17571455903b184a10026603a1de078428ba155e7ce9358c5f6"}, - {file = "ruff-0.4.7-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ad1b20e66a44057c326168437d680a2166c177c939346b19c0d6b08a62a37589"}, - {file = "ruff-0.4.7-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cbf5d818553add7511c38b05532d94a407f499d1a76ebb0cad0374e32bc67202"}, - {file = "ruff-0.4.7-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:50e9651578b629baec3d1513b2534de0ac7ed7753e1382272b8d609997e27e83"}, - {file = "ruff-0.4.7-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:8874a9df7766cb956b218a0a239e0a5d23d9e843e4da1e113ae1d27ee420877a"}, - {file = "ruff-0.4.7-py3-none-musllinux_1_2_i686.whl", hash = "sha256:b9de9a6e49f7d529decd09381c0860c3f82fa0b0ea00ea78409b785d2308a567"}, - {file = "ruff-0.4.7-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:13a1768b0691619822ae6d446132dbdfd568b700ecd3652b20d4e8bc1e498f78"}, - {file = "ruff-0.4.7-py3-none-win32.whl", hash = "sha256:769e5a51df61e07e887b81e6f039e7ed3573316ab7dd9f635c5afaa310e4030e"}, - {file = "ruff-0.4.7-py3-none-win_amd64.whl", hash = "sha256:9e3ab684ad403a9ed1226894c32c3ab9c2e0718440f6f50c7c5829932bc9e054"}, - {file = "ruff-0.4.7-py3-none-win_arm64.whl", hash = "sha256:10f2204b9a613988e3484194c2c9e96a22079206b22b787605c255f130db5ed7"}, - {file = "ruff-0.4.7.tar.gz", hash = "sha256:2331d2b051dc77a289a653fcc6a42cce357087c5975738157cd966590b18b5e1"}, + {file = "ruff-0.4.10-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:5c2c4d0859305ac5a16310eec40e4e9a9dec5dcdfbe92697acd99624e8638dac"}, + {file = "ruff-0.4.10-py3-none-macosx_11_0_arm64.whl", hash = "sha256:a79489607d1495685cdd911a323a35871abfb7a95d4f98fc6f85e799227ac46e"}, + {file = "ruff-0.4.10-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b1dd1681dfa90a41b8376a61af05cc4dc5ff32c8f14f5fe20dba9ff5deb80cd6"}, + {file = "ruff-0.4.10-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c75c53bb79d71310dc79fb69eb4902fba804a81f374bc86a9b117a8d077a1784"}, + {file = "ruff-0.4.10-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:18238c80ee3d9100d3535d8eb15a59c4a0753b45cc55f8bf38f38d6a597b9739"}, + {file = "ruff-0.4.10-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:d8f71885bce242da344989cae08e263de29752f094233f932d4f5cfb4ef36a81"}, + {file = "ruff-0.4.10-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:330421543bd3222cdfec481e8ff3460e8702ed1e58b494cf9d9e4bf90db52b9d"}, + {file = "ruff-0.4.10-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9e9b6fb3a37b772628415b00c4fc892f97954275394ed611056a4b8a2631365e"}, + {file = "ruff-0.4.10-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0f54c481b39a762d48f64d97351048e842861c6662d63ec599f67d515cb417f6"}, + {file = "ruff-0.4.10-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:67fe086b433b965c22de0b4259ddfe6fa541c95bf418499bedb9ad5fb8d1c631"}, + {file = "ruff-0.4.10-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:acfaaab59543382085f9eb51f8e87bac26bf96b164839955f244d07125a982ef"}, + {file = "ruff-0.4.10-py3-none-musllinux_1_2_i686.whl", hash = "sha256:3cea07079962b2941244191569cf3a05541477286f5cafea638cd3aa94b56815"}, + {file = "ruff-0.4.10-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:338a64ef0748f8c3a80d7f05785930f7965d71ca260904a9321d13be24b79695"}, + {file = "ruff-0.4.10-py3-none-win32.whl", hash = "sha256:ffe3cd2f89cb54561c62e5fa20e8f182c0a444934bf430515a4b422f1ab7b7ca"}, + {file = "ruff-0.4.10-py3-none-win_amd64.whl", hash = "sha256:67f67cef43c55ffc8cc59e8e0b97e9e60b4837c8f21e8ab5ffd5d66e196e25f7"}, + {file = "ruff-0.4.10-py3-none-win_arm64.whl", hash = "sha256:dd1fcee327c20addac7916ca4e2653fbbf2e8388d8a6477ce5b4e986b68ae6c0"}, + {file = "ruff-0.4.10.tar.gz", hash = "sha256:3aa4f2bc388a30d346c56524f7cacca85945ba124945fe489952aadb6b5cd804"}, ] [[package]] @@ -586,15 +587,15 @@ mpmath = ">=1.1.0,<1.4.0" [[package]] name = "tbb" -version = "2021.12.0" +version = "2021.13.0" description = "IntelĀ® oneAPI Threading Building Blocks (oneTBB)" optional = false python-versions = "*" files = [ - {file = "tbb-2021.12.0-py2.py3-none-manylinux1_i686.whl", hash = "sha256:f2cc9a7f8ababaa506cbff796ce97c3bf91062ba521e15054394f773375d81d8"}, - {file = "tbb-2021.12.0-py2.py3-none-manylinux1_x86_64.whl", hash = "sha256:a925e9a7c77d3a46ae31c34b0bb7f801c4118e857d137b68f68a8e458fcf2bd7"}, - {file = "tbb-2021.12.0-py3-none-win32.whl", hash = "sha256:b1725b30c174048edc8be70bd43bb95473f396ce895d91151a474d0fa9f450a8"}, - {file = "tbb-2021.12.0-py3-none-win_amd64.whl", hash = "sha256:fc2772d850229f2f3df85f1109c4844c495a2db7433d38200959ee9265b34789"}, + {file = "tbb-2021.13.0-py2.py3-none-manylinux1_i686.whl", hash = "sha256:a2567725329639519d46d92a2634cf61e76601dac2f777a05686fea546c4fe4f"}, + {file = "tbb-2021.13.0-py2.py3-none-manylinux1_x86_64.whl", hash = "sha256:aaf667e92849adb012b8874d6393282afc318aca4407fc62f912ee30a22da46a"}, + {file = "tbb-2021.13.0-py3-none-win32.whl", hash = "sha256:6669d26703e9943f6164c6407bd4a237a45007e79b8d3832fe6999576eaaa9ef"}, + {file = "tbb-2021.13.0-py3-none-win_amd64.whl", hash = "sha256:3528a53e4bbe64b07a6112b4c5a00ff3c61924ee46c9c68e004a1ac7ad1f09c3"}, ] [[package]] @@ -610,21 +611,21 @@ files = [ [[package]] name = "torch" -version = "2.3.0+cpu" +version = "2.3.1+cpu" description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" optional = false python-versions = ">=3.8.0" files = [ - {file = "torch-2.3.0+cpu-cp310-cp310-linux_x86_64.whl", hash = "sha256:e3c220702d82c7596924150e0499fbbffcf62a88a59adc860fa357cd8dc1c302"}, - {file = "torch-2.3.0+cpu-cp310-cp310-win_amd64.whl", hash = "sha256:ab0c05525195b8fecdf2ea75968ed32ccd87dff16381b6e13249babb4a9596ff"}, - {file = "torch-2.3.0+cpu-cp311-cp311-linux_x86_64.whl", hash = "sha256:97a38b25ee0e3d020691e7846efbca62a3d8a57645c027dcb5ba0adfec36fe55"}, - {file = "torch-2.3.0+cpu-cp311-cp311-win_amd64.whl", hash = "sha256:a8ac195974be6f067245bae8156b8c06fb0a723b0eed8f2e244b5dd58c7e2a49"}, - {file = "torch-2.3.0+cpu-cp312-cp312-linux_x86_64.whl", hash = "sha256:a8982e52185771591dad577a124a7770f72f288f8ae5833317b1e329c0d2f07e"}, - {file = "torch-2.3.0+cpu-cp312-cp312-win_amd64.whl", hash = "sha256:483131a7997995d867313ee902743084e844e830ab2a0c5e079c61ec2da3cd17"}, - {file = "torch-2.3.0+cpu-cp38-cp38-linux_x86_64.whl", hash = "sha256:8c52484880d5fbe511cffc255dd34847ddeced3f94334c6bf7eb2b0445f10cb4"}, - {file = "torch-2.3.0+cpu-cp38-cp38-win_amd64.whl", hash = "sha256:28a11bcc0d709b397d675cff689707019b8cc122e6bf328b57b900f47c36f156"}, - {file = "torch-2.3.0+cpu-cp39-cp39-linux_x86_64.whl", hash = "sha256:1e86e225e472392440ace378ba3165b5e87648e8b5fbf16adc41c0df881c38b8"}, - {file = "torch-2.3.0+cpu-cp39-cp39-win_amd64.whl", hash = "sha256:5c2afdff80203eaabf4c223a294c2f465020b3360e8e87f76b52ace9c5801ebe"}, + {file = "torch-2.3.1+cpu-cp310-cp310-linux_x86_64.whl", hash = "sha256:d679e21d871982b9234444331a26350902cfd2d5ca44ce6f49896af8b3a3087d"}, + {file = "torch-2.3.1+cpu-cp310-cp310-win_amd64.whl", hash = "sha256:500bf790afc2fd374a15d06213242e517afccc50a46ea5955d321a9a68003335"}, + {file = "torch-2.3.1+cpu-cp311-cp311-linux_x86_64.whl", hash = "sha256:a272defe305dbd944aa28a91cc3db0f0149495b3ebec2e39723a7224fa05dc57"}, + {file = "torch-2.3.1+cpu-cp311-cp311-win_amd64.whl", hash = "sha256:d2965eb54d3c8818e2280a54bd53e8246a6bb34e4b10bd19c59f35b611dd9f05"}, + {file = "torch-2.3.1+cpu-cp312-cp312-linux_x86_64.whl", hash = "sha256:2141a6cb7021adf2f92a0fd372cfeac524ba460bd39ce3a641d30a561e41f69a"}, + {file = "torch-2.3.1+cpu-cp312-cp312-win_amd64.whl", hash = "sha256:6acdca2530462611095c44fd95af75ecd5b9646eac813452fe0adf31a9bc310a"}, + {file = "torch-2.3.1+cpu-cp38-cp38-linux_x86_64.whl", hash = "sha256:cab92d5101e6db686c5525e04d87cedbcf3a556073d71d07fbe7d1ce09630ffb"}, + {file = "torch-2.3.1+cpu-cp38-cp38-win_amd64.whl", hash = "sha256:dbc784569a367fd425158cf4ae82057dd3011185ba5fc68440432ba0562cb5b2"}, + {file = "torch-2.3.1+cpu-cp39-cp39-linux_x86_64.whl", hash = "sha256:a3cb8e61ba311cee1bb7463cbdcf3ebdfd071e2091e74c5785e3687eb02819f9"}, + {file = "torch-2.3.1+cpu-cp39-cp39-win_amd64.whl", hash = "sha256:df68668056e62c0332e03f43d9da5d4278b39df1ba58d30ec20d34242070955d"}, ] [package.dependencies] @@ -647,13 +648,13 @@ reference = "torch" [[package]] name = "typing-extensions" -version = "4.12.1" +version = "4.12.2" description = "Backported and Experimental Type Hints for Python 3.8+" optional = false python-versions = ">=3.8" files = [ - {file = "typing_extensions-4.12.1-py3-none-any.whl", hash = "sha256:6024b58b69089e5a89c347397254e35f1bf02a907728ec7fee9bf0fe837d203a"}, - {file = "typing_extensions-4.12.1.tar.gz", hash = "sha256:915f5e35ff76f56588223f15fdd5938f9a1cf9195c0de25130c627e4d597f6d1"}, + {file = "typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d"}, + {file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"}, ] [extras] diff --git a/pyproject.toml b/pyproject.toml index f324215e..24a57b70 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,12 +13,12 @@ keywords = [ "pytorch", "deep-learning", "optimizer", "lr scheduler", "A2Grad", "ASGD", "AccSGD", "AdaBelief", "AdaBound", "AdaDelta", "AdaFactor", "AdaMax", "AdaMod", "AdaNorm", "AdaPNM", "AdaSmooth", "AdaHessian", "Adai", "Adalite", "AdamP", "AdamS", "Adan", "AggMo", "Aida", "AliG", "Amos", "Apollo", "AvaGrad", "bSAM", "CAME", "DAdaptAdaGrad", - "DAdaptAdam", "DAdaptAdan", "DAdaptSGD", "DAdaptLion", "DiffGrad", "FAdam", "Fromage", "GaLore", "Gravity", "GSAM", - "LARS", "Lamb", "Lion", "LOMO", "Lookahead", "MADGRAD", "MSVAG", "Nero", "NovoGrad", "PAdam", "PCGrad", "PID", - "PNM", "Prodigy", "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM", "ScheduleFreeSGD", - "ScheduleFreeAdamW", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD", "SM3", "SopihaH", "SRMM", "SWATS", - "Tiger", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1", "Dice", "LDAM", "Jaccard", - "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes", + "DAdaptAdam", "DAdaptAdan", "DAdaptSGD", "DAdaptLion", "DiffGrad", "FAdam", "Fromage", "GaLore", "Gravity", + "GrokFast", "GSAM", "LARS", "Lamb", "Lion", "LOMO", "Lookahead", "MADGRAD", "MSVAG", "Nero", "NovoGrad", "PAdam", + "PCGrad", "PID", "PNM", "Prodigy", "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM", + "ScheduleFreeSGD", "ScheduleFreeAdamW", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD", "SM3", "SopihaH", + "SRMM", "SWATS", "Tiger", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1", "Dice", "LDAM", + "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes", ] classifiers = [ "License :: OSI Approved :: Apache Software License", diff --git a/pytorch_optimizer/__init__.py b/pytorch_optimizer/__init__.py index 03e2e63a..ad34e92c 100644 --- a/pytorch_optimizer/__init__.py +++ b/pytorch_optimizer/__init__.py @@ -60,6 +60,7 @@ from pytorch_optimizer.optimizer.galore import GaLore, GaLoreProjector from pytorch_optimizer.optimizer.gc import centralize_gradient from pytorch_optimizer.optimizer.gravity import Gravity +from pytorch_optimizer.optimizer.grokfast import GrokFastAdamW, gradfilter_ema, gradfilter_ma from pytorch_optimizer.optimizer.lamb import Lamb from pytorch_optimizer.optimizer.lars import LARS from pytorch_optimizer.optimizer.lion import Lion @@ -192,6 +193,7 @@ ScheduleFreeSGD, ScheduleFreeAdamW, FAdam, + GrokFastAdamW, ] OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST} diff --git a/pytorch_optimizer/optimizer/grokfast.py b/pytorch_optimizer/optimizer/grokfast.py new file mode 100644 index 00000000..22b57a9d --- /dev/null +++ b/pytorch_optimizer/optimizer/grokfast.py @@ -0,0 +1,232 @@ +import math +from collections import deque +from typing import Dict, Literal, Optional + +import torch +from torch import nn +from torch.optim.optimizer import Optimizer + +from pytorch_optimizer.base.exception import NoSparseGradientError +from pytorch_optimizer.base.optimizer import BaseOptimizer +from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS + +FILTER_TYPE = Literal['mean', 'sum'] + + +@torch.no_grad() +def gradfilter_ma( + model: nn.Module, + grads: Optional[Dict[str, deque]] = None, + window_size: int = 100, + lamb: float = 5.0, + filter_type: FILTER_TYPE = 'mean', + warmup: bool = True, +) -> Dict[str, deque]: + r"""Grokfast-MA. + + Example: + ------- + Here's an example:: + + loss.backwards() # Calculate the gradients. + + grads = gradfilter_ma(model, grads=grads, window_size=window_size, lamb=lamb) + + optimizer.step() # Call the optimizer. + + :param model: nn.Module. model that contains every trainable parameters. + :param grads: Optional[Dict[str, deque]]. running memory (Queue for windowed moving average). initialize by setting + it to None. feed the output of the method recursively after on. + :param window_size: int. the width of the filter window. additional memory requirements increases linearly with + respect to the windows size. + :param lamb: float. amplifying factor hyperparameter of the filter. + :param filter_type: FILTER_TYPE. aggregation method for the running queue. + :param warmup: bool. if true, filter is not applied until the queue is filled. + """ + if grads is None: + grads = {n: deque(maxlen=window_size) for n, p in model.named_parameters() if p.requires_grad} + + for n, p in model.named_parameters(): + if p.requires_grad: + grads[n].append(p.grad) + + if not warmup or len(grads[n]) == window_size: + if filter_type == 'mean': + avg = sum(grads[n]) / len(grads[n]) + elif filter_type == 'sum': + avg = sum(grads[n]) + else: + raise ValueError(f'not supported filter_type {filter_type}') + + p.grad.add_(avg, alpha=lamb) + + return grads + + +@torch.no_grad() +def gradfilter_ema( + model: nn.Module, + grads: Optional[Dict[str, torch.Tensor]] = None, + alpha: float = 0.98, + lamb: float = 2.0, +) -> Dict[str, torch.Tensor]: + r"""Grokfast. + + Example: + ------- + Here's an example:: + + loss.backwards() # Calculate the gradients. + + grads = gradfilter_ema(model, grads=grads, alpha=alpha, lamb=lamb) + + optimizer.step() # Call the optimizer. + + :param model: nn.Module. model that contains every trainable parameters. + :param grads: Optional[Dict[str, deque]]. running memory (EMA). Initialize by setting it to None. Feed the output + of the method recursively after on. + :param alpha: int. momentum hyperparameter of the EMA. + :param lamb: float. amplifying factor hyperparameter of the filter. + """ + if grads is None: + grads = {n: p.grad for n, p in model.named_parameters() if p.requires_grad} + + for n, p in model.named_parameters(): + if p.requires_grad: + grads[n].mul_(alpha).add_(p.grad, alpha=1.0 - alpha) + p.grad.add_(grads[n], alpha=lamb) + + return grads + + +class GrokFastAdamW(Optimizer, BaseOptimizer): + r"""Accelerated Grokking by Amplifying Slow Gradients with AdamW. + + :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. + :param lr: float. learning rate. + :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace. + :param grokfast: bool. whether to use grokfast. + :param grokfast_alpha: float. momentum hyperparameter of the EMA. + :param grokfast_lamb: float. amplifying factor hyperparameter of the filter.. + :param grokfast_after_step: int. warmup step for grokfast. + :param weight_decay: float. weight decay (L2 penalty). + :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW. + :param fixed_decay: bool. fix weight decay. + :param eps: float. term added to the denominator to improve numerical stability. + """ + + def __init__( + self, + params: PARAMETERS, + lr: float = 1e-4, + betas: BETAS = (0.9, 0.99), + grokfast: bool = True, + grokfast_alpha: float = 0.98, + grokfast_lamb: float = 2.0, + grokfast_after_step: int = 0, + weight_decay: float = 0.0, + weight_decouple: bool = True, + fixed_decay: bool = False, + normalize_lr: bool = True, + eps: float = 1e-8, + ): + self.validate_learning_rate(lr) + self.validate_betas(betas) + self.validate_non_negative(weight_decay, 'weight_decay') + self.validate_range(grokfast_alpha, 'grokfast_alpha', 0.0, 1.0) + self.validate_non_negative(eps, 'eps') + + if grokfast and normalize_lr: + lr /= 1.0 + grokfast_lamb + + defaults: DEFAULTS = { + 'lr': lr, + 'betas': betas, + 'weight_decay': weight_decay, + 'weight_decouple': weight_decouple, + 'fixed_decay': fixed_decay, + 'grokfast': grokfast, + 'grokfast_alpha': grokfast_alpha, + 'grokfast_lamb': grokfast_lamb, + 'grokfast_after_step': grokfast_after_step, + 'eps': eps, + } + super().__init__(params, defaults) + + def __str__(self) -> str: + return 'GrokFastAdamW' + + @torch.no_grad() + def reset(self): + for group in self.param_groups: + group['step'] = 0 + for p in group['params']: + state = self.state[p] + + state['exp_avg'] = torch.zeros_like(p) + state['exp_avg_sq'] = torch.zeros_like(p) + + @torch.no_grad() + def step(self, closure: CLOSURE = None) -> LOSS: + loss: LOSS = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + if 'step' in group: + group['step'] += 1 + else: + group['step'] = 1 + + beta1, beta2 = group['betas'] + + bias_correction1: float = 1.0 - beta1 ** group['step'] + bias_correction2_sq: float = math.sqrt(1.0 - beta2 ** group['step']) + + should_grokfast: bool = ( + group['grokfast'] and group['step'] > group['grokfast_after_step'] and group['grokfast_lamb'] > 0 + ) + + for p in group['params']: + if p.grad is None: + continue + + grad = p.grad + if grad.is_sparse: + raise NoSparseGradientError(str(self)) + + state = self.state[p] + + if len(state) == 0: + state['exp_avg'] = torch.zeros_like(p) + state['exp_avg_sq'] = torch.zeros_like(p) + if should_grokfast: + state['grok_exp_avg'] = grad.clone() + + self.apply_weight_decay( + p=p, + grad=grad, + lr=group['lr'], + weight_decay=group['weight_decay'], + weight_decouple=group['weight_decouple'], + fixed_decay=group['fixed_decay'], + ) + + if should_grokfast: + grok_exp_avg = state['grok_exp_avg'] + grok_exp_avg.lerp_(grad, weight=1.0 - group['grokfast_alpha']) + + grad.add_(grok_exp_avg, alpha=group['grokfast_lamb']) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2) + + de_nom = exp_avg_sq.sqrt().div_(bias_correction2_sq).clamp_(min=group['eps']) + + update = exp_avg.div(bias_correction1).div_(de_nom) + + p.add_(update, alpha=-group['lr']) + + return loss diff --git a/requirements-dev.txt b/requirements-dev.txt index 6a2ab780..033881f4 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -5,8 +5,8 @@ click==8.1.7 ; python_version >= "3.8" and python_full_version < "4.0.0" colorama==0.4.6 ; python_version >= "3.8" and python_full_version < "4.0.0" and (sys_platform == "win32" or platform_system == "Windows") coverage[toml]==7.5.3 ; python_version >= "3.8" and python_full_version < "4.0.0" exceptiongroup==1.2.1 ; python_version >= "3.8" and python_version < "3.11" -filelock==3.14.0 ; python_version >= "3.8" and python_full_version < "4.0.0" -fsspec==2024.5.0 ; python_version >= "3.8" and python_full_version < "4.0.0" +filelock==3.15.3 ; python_version >= "3.8" and python_full_version < "4.0.0" +fsspec==2024.6.0 ; python_version >= "3.8" and python_full_version < "4.0.0" iniconfig==2.0.0 ; python_version >= "3.8" and python_full_version < "4.0.0" intel-openmp==2021.4.0 ; python_version >= "3.8" and python_full_version < "4.0.0" and platform_system == "Windows" isort==5.13.2 ; python_version >= "3.8" and python_full_version < "4.0.0" @@ -17,15 +17,15 @@ mpmath==1.3.0 ; python_version >= "3.8" and python_full_version < "4.0.0" mypy-extensions==1.0.0 ; python_version >= "3.8" and python_full_version < "4.0.0" networkx==3.1 ; python_version >= "3.8" and python_full_version < "4.0.0" numpy==1.24.4 ; python_version >= "3.8" and python_full_version < "4.0.0" -packaging==24.0 ; python_version >= "3.8" and python_full_version < "4.0.0" +packaging==24.1 ; python_version >= "3.8" and python_full_version < "4.0.0" pathspec==0.12.1 ; python_version >= "3.8" and python_full_version < "4.0.0" platformdirs==4.2.2 ; python_version >= "3.8" and python_full_version < "4.0.0" pluggy==1.5.0 ; python_version >= "3.8" and python_full_version < "4.0.0" pytest-cov==5.0.0 ; python_version >= "3.8" and python_full_version < "4.0.0" -pytest==8.2.1 ; python_version >= "3.8" and python_full_version < "4.0.0" -ruff==0.4.7 ; python_version >= "3.8" and python_full_version < "4.0.0" +pytest==8.2.2 ; python_version >= "3.8" and python_full_version < "4.0.0" +ruff==0.4.10 ; python_version >= "3.8" and python_full_version < "4.0.0" sympy==1.12.1 ; python_version >= "3.8" and python_full_version < "4.0.0" -tbb==2021.12.0 ; python_version >= "3.8" and python_full_version < "4.0.0" and platform_system == "Windows" +tbb==2021.13.0 ; python_version >= "3.8" and python_full_version < "4.0.0" and platform_system == "Windows" tomli==2.0.1 ; python_version >= "3.8" and python_full_version <= "3.11.0a6" -torch==2.3.0+cpu ; python_version >= "3.8" and python_full_version < "4.0.0" -typing-extensions==4.12.1 ; python_version >= "3.8" and python_full_version < "4.0.0" +torch==2.3.1+cpu ; python_version >= "3.8" and python_full_version < "4.0.0" +typing-extensions==4.12.2 ; python_version >= "3.8" and python_full_version < "4.0.0" diff --git a/requirements.txt b/requirements.txt index 9392e86f..c7556f23 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ --extra-index-url https://download.pytorch.org/whl/cpu -filelock==3.14.0 ; python_version >= "3.8" and python_full_version < "4.0.0" -fsspec==2024.5.0 ; python_version >= "3.8" and python_full_version < "4.0.0" +filelock==3.15.3 ; python_version >= "3.8" and python_full_version < "4.0.0" +fsspec==2024.6.0 ; python_version >= "3.8" and python_full_version < "4.0.0" intel-openmp==2021.4.0 ; python_version >= "3.8" and python_full_version < "4.0.0" and platform_system == "Windows" jinja2==3.1.4 ; python_version >= "3.8" and python_full_version < "4.0.0" markupsafe==2.1.5 ; python_version >= "3.8" and python_full_version < "4.0.0" @@ -10,6 +10,6 @@ mpmath==1.3.0 ; python_version >= "3.8" and python_full_version < "4.0.0" networkx==3.1 ; python_version >= "3.8" and python_full_version < "4.0.0" numpy==1.24.4 ; python_version >= "3.8" and python_full_version < "4.0.0" sympy==1.12.1 ; python_version >= "3.8" and python_full_version < "4.0.0" -tbb==2021.12.0 ; python_version >= "3.8" and python_full_version < "4.0.0" and platform_system == "Windows" -torch==2.3.0+cpu ; python_version >= "3.8" and python_full_version < "4.0.0" -typing-extensions==4.12.1 ; python_version >= "3.8" and python_full_version < "4.0.0" +tbb==2021.13.0 ; python_version >= "3.8" and python_full_version < "4.0.0" and platform_system == "Windows" +torch==2.3.1+cpu ; python_version >= "3.8" and python_full_version < "4.0.0" +typing-extensions==4.12.2 ; python_version >= "3.8" and python_full_version < "4.0.0" diff --git a/tests/constants.py b/tests/constants.py index eee2ecb1..b8c4f5e1 100644 --- a/tests/constants.py +++ b/tests/constants.py @@ -49,6 +49,7 @@ Fromage, GaLore, Gravity, + GrokFastAdamW, Lamb, Lion, Nero, @@ -129,6 +130,7 @@ 'bsam', 'schedulefreeadamw', 'fadam', + 'grokfastadamw', ] VALID_LR_SCHEDULER_NAMES: List[str] = [ @@ -448,6 +450,7 @@ (ScheduleFreeSGD, {'lr': 1e0, 'weight_decay': 1e-3}, 5), (ScheduleFreeAdamW, {'lr': 1e0, 'weight_decay': 1e-3}, 5), (FAdam, {'lr': 1e0, 'weight_decay': 1e-3}, 5), + (GrokFastAdamW, {'lr': 1e0, 'weight_decay': 1e-3}, 10), ] ADANORM_SUPPORTED_OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [ (AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'adanorm': True}, 10), diff --git a/tests/test_load_modules.py b/tests/test_load_modules.py index d5646b44..26a76d64 100644 --- a/tests/test_load_modules.py +++ b/tests/test_load_modules.py @@ -38,7 +38,7 @@ def test_load_lr_scheduler_invalid(invalid_lr_scheduler_names): def test_get_supported_optimizers(): - assert len(get_supported_optimizers()) == 67 + assert len(get_supported_optimizers()) == 68 def test_get_supported_lr_schedulers(): diff --git a/tests/test_optimizers.py b/tests/test_optimizers.py index bce91239..e0abec4e 100644 --- a/tests/test_optimizers.py +++ b/tests/test_optimizers.py @@ -13,6 +13,8 @@ Lookahead, PCGrad, ProportionScheduler, + gradfilter_ema, + gradfilter_ma, load_optimizer, ) from pytorch_optimizer.base.exception import NoClosureError, ZeroParameterSizeError @@ -608,3 +610,33 @@ def test_schedule_free_train_mode(): opt.reset() opt.eval() opt.train() + + +@pytest.mark.parametrize('filter_type', ['mean', 'sum']) +def test_grokfast_ma(filter_type, environment): + _, model, _ = environment + + model.fc1.weight.grad = torch.randn(2, 2) + model.fc1.bias.grad = torch.randn(2) + model.fc2.weight.grad = torch.randn(1, 2) + model.fc2.bias.grad = torch.randn(1) + + _ = gradfilter_ma(model, None, window_size=1, filter_type=filter_type, warmup=False) + + +def test_grokfast_ma_invalid(environment): + _, model, _ = environment + + with pytest.raises(ValueError): + _ = gradfilter_ma(model, None, window_size=1, filter_type='asdf', warmup=False) + + +def test_grokfast_ema(environment): + _, model, _ = environment + + model.fc1.weight.grad = torch.randn(2, 2) + model.fc1.bias.grad = torch.randn(2) + model.fc2.weight.grad = torch.randn(1, 2) + model.fc2.bias.grad = torch.randn(1) + + _ = gradfilter_ema(model, None)