
Extend autograd support #30

Draft
wants to merge 5 commits into master

Conversation

askorikov
Contributor

This PR adds support for forward-mode and higher-order differentiation in the PyTorch autograd engine. Additionally, it enables integration with torch.func, which implements JAX-style composable functional transforms in PyTorch, allowing flexible and elegant use of different types of automatic differentiation (vectorizing map functionality is also implemented in this PR).

Integration with torch.func breaks compatibility with PyTorch < 2.0, however, so we need to assess whether it is safe to drop support for older versions at this point.
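
For reference, a minimal sketch of the kind of usage this integration enables (the function `f` below is a generic stand-in, not code from this project): torch.func transforms such as `grad`, `jacfwd`, and `vmap` compose to give reverse-mode gradients, forward-over-reverse higher-order derivatives, and batched gradients.

```python
import torch
from torch.func import grad, jacfwd, vmap

def f(x):
    # Placeholder scalar-valued function standing in for a differentiable operator
    return (x ** 3).sum()

x = torch.randn(5)

g = grad(f)(x)             # reverse-mode gradient, shape (5,)
H = jacfwd(grad(f))(x)     # forward-over-reverse Hessian, shape (5, 5)

xs = torch.randn(8, 5)
gs = vmap(grad(f))(xs)     # per-sample gradients over a batch, shape (8, 5)
```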

To do:

  • Check for use cases requiring PyTorch < 2.0
  • Add tests

With the goal of enabling forward-mode and higher-order differentiation, checking `input.requires_grad` is no longer sufficient to determine whether the relevant arguments will be needed later.
WARNING: this breaks compatibility with PyTorch < 2.0
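
As an illustration of why the `requires_grad` check falls short (a generic PyTorch sketch, not code from this repository): under forward-mode AD, derivative information travels via dual tensors whose primal can have `requires_grad == False`, so logic that keys off that flag would wrongly skip work that forward-mode differentiation still needs.

```python
import torch
import torch.autograd.forward_ad as fwAD

primal = torch.randn(3)   # requires_grad is False
tangent = torch.randn(3)

with fwAD.dual_level():
    dual = fwAD.make_dual(primal, tangent)
    assert not dual.requires_grad          # forward-mode AD does not set this flag
    out = dual.sin()
    # The tangent (JVP) still propagates even though requires_grad is False.
    jvp = fwAD.unpack_dual(out).tangent
```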