Refactor code for sum Triton kernels #2303

Closed
wants to merge 2 commits into from

jananisriram
Contributor

Summary:
Refactor code to improve readability and logical flow for cases which select the sum Triton kernel implementation to run. Create helper functions for the following cases:

  • Reduce N-dimensional input to scalar output
  • Reduce 2-dimensional input to 1-dimensional output
  • Reduce 3-dimensional input along dimension 1 to 2-dimensional output
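A minimal sketch of how the three cases above might be dispatched. The function name `select_sum_case` and the returned case labels are hypothetical and do not come from this PR; only the three shape cases are taken from the summary.

```python
from typing import Optional, Sequence


def select_sum_case(input_shape: Sequence[int], reduce_dim: Optional[int]) -> str:
    """Pick a label for the helper that would handle this reduction.

    Hypothetical dispatcher: the labels are illustrative, not the PR's names.
    """
    ndim = len(input_shape)
    if reduce_dim is None:
        # No reduction dimension given: collapse the whole input to a scalar.
        return "reduce_to_scalar"
    if ndim == 2 and reduce_dim in (0, 1):
        # 2-dimensional input reduced along one dimension gives a 1D output.
        return "reduce_2d_to_1d"
    if ndim == 3 and reduce_dim == 1:
        # 3-dimensional input reduced along dimension 1 gives a 2D output.
        return "reduce_3d_dim1_to_2d"
    raise NotImplementedError(
        f"unsupported case: shape={tuple(input_shape)}, reduce_dim={reduce_dim}"
    )
```

Keeping the case selection in one small function leaves each helper responsible for a single input/output shape.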

Add command line argument parsing for the input_dim parameter, which specifies the number of dimensions desired in kernel inputs.
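As an illustration, the argument could be parsed with `argparse` roughly as follows. The flag spelling `--input-dim`, the default value, and the function name `parse_args` are assumptions; the summary only names the `input_dim` parameter.

```python
import argparse
from typing import List, Optional


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse benchmark arguments (illustrative sketch, not the PR's code)."""
    parser = argparse.ArgumentParser(description="sum kernel benchmark (sketch)")
    parser.add_argument(
        "--input-dim",
        type=int,
        default=2,  # assumed default, not stated in the PR
        help="Number of dimensions desired in kernel inputs",
    )
    args = parser.parse_args(argv)
    if args.input_dim < 1:
        parser.error("--input-dim must be a positive integer")
    return args
```

With this sketch, `parse_args(["--input-dim", "3"]).input_dim` evaluates to `3`.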

Modify absolute tolerance to account for floating-point operation error.

Reviewed By: jbschlosser

Differential Revision: D58488137

jananisriram and others added 2 commits June 14, 2024 14:17
Summary:
Extend support for reducing across individual dimensions on 2-dimensional matrices by allowing for varying block sizes on both the `M` (first) and `N` (second) dimensions.

The existing kernel performed a simplified reduction that assumed the entire reduction dimension fit within one thread block. The new kernel removes this assumption, so both the reduction and the non-reduction dimensions can span multiple thread blocks. It also enables autotuning of block sizes for both the `M` and `N` dimensions.
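The tiling idea can be sketched in plain Python, with no Triton involved. This is not the PR's kernel: `blocked_sum_dim0`, `block_m`, and `block_n` are hypothetical names, and the loops only mimic how an `M` x `N` input reduced along `M` could be processed in tiles when neither dimension fits in one block.

```python
from typing import List


def blocked_sum_dim0(matrix: List[List[float]], block_m: int, block_n: int) -> List[float]:
    """Sum an M x N matrix along dimension 0 using block_m x block_n tiles."""
    m = len(matrix)
    n = len(matrix[0]) if m else 0
    out = [0.0] * n
    # Each step of the outer loop stands in for one block of output columns.
    for n0 in range(0, n, block_n):
        # The reduction dimension is walked in tiles instead of all at once.
        for m0 in range(0, m, block_m):
            # min(...) plays the role of the bounds mask on a partial tile.
            for col in range(n0, min(n0 + block_n, n)):
                for row in range(m0, min(m0 + block_m, m)):
                    out[col] += matrix[row][col]
    return out
```

Because the tile sizes are plain parameters, different `(block_m, block_n)` pairs can be tried without changing the result, which is the property autotuning relies on.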

For 1D results, add a `sum_then_add` configuration option that selects which kernel variant runs. `sum_then_add` sums each block of input and adds these partial sums into a buffer. `add_then_sum` adds blocks of raw input into a buffer and reduces the buffer last.
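The two accumulation orders can be compared in a small pure-Python sketch. It reduces a single vector rather than a 2D input, and the function bodies are an interpretation of the description above, not the PR's Triton code.

```python
from typing import List


def sum_then_add(values: List[float], block_size: int) -> float:
    """Reduce each block first, then add the partial sum into a scalar buffer."""
    acc = 0.0
    for start in range(0, len(values), block_size):
        block = values[start:start + block_size]
        acc += sum(block)  # one reduction per block
    return acc


def add_then_sum(values: List[float], block_size: int) -> float:
    """Add raw blocks elementwise into a block-sized buffer, reduce it last."""
    buffer = [0.0] * block_size
    for start in range(0, len(values), block_size):
        block = values[start:start + block_size]
        for i, v in enumerate(block):
            buffer[i] += v  # no reduction inside the loop
    return sum(buffer)  # single reduction at the end
```

Both functions compute the same mathematical sum. They differ in how many reductions run and in the order of floating-point additions, so their results can differ slightly for non-integer inputs.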

Differential Revision: D58313958

Reviewed By: davidberard98

Summary:
Refactor code to improve readability and logical flow for cases which select the `sum` Triton kernel implementation to run. Create helper functions for the following cases:
- Reduce N-dimensional input to scalar output
- Reduce 2-dimensional input to 1-dimensional output
- Reduce 3-dimensional input along dimension 1 to 2-dimensional output

Add command line argument parsing for the `input_dim` parameter, which specifies the number of dimensions desired in kernel inputs.

Modify absolute tolerance to account for floating-point operation error.
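To show why an absolute tolerance is needed at all, the sketch below compares a step-by-step accumulation with a correctly rounded sum. It uses Python's double-precision floats as a stand-in for the kernel's data types, and `within_atol` is a hypothetical helper, not the comparison used in the PR.

```python
import math


def running_sum(values):
    """Accumulate left to right, rounding after every addition."""
    acc = 0.0
    for v in values:
        acc += v
    return acc


def within_atol(actual: float, expected: float, atol: float) -> bool:
    """Return True when the absolute error is at most atol."""
    return abs(actual - expected) <= atol


values = [0.1] * 10
naive = running_sum(values)   # rounding error builds up across additions
exact = math.fsum(values)     # correctly rounded floating-point sum
```

Here `naive` and `exact` differ in the last bits, so an exact equality check fails while `within_atol(naive, exact, 1e-12)` passes. A kernel that changes the order of additions is subject to the same effect.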

Reviewed By: jbschlosser

Differential Revision: D58488137
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D58488137

@facebook-github-bot
Contributor

This pull request has been merged in 339ccfd.
