# rfc: proposal for block level APIs #1852
$$C = \beta C + \alpha \sum_i A_i \cdot B_i + bias$$

with

- $A_i$ a set of matrices of dimension $M \times K$
- $B_i$ a set of matrices of dimension $K \times N$
- $C$ a matrix of dimension $M \times N$
- $bias$ a vector of dimension $N$.
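As a reference point for the semantics, the computation above can be written as a naive loop nest. This is an illustrative reference in plain C++, not the proposed API; row-major float buffers and identical block shapes across the batch are assumed:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Naive reference for C = beta * C + alpha * sum_i A_i * B_i + bias.
// A holds pointers to batch-many M x K blocks, B to K x N blocks,
// all row-major; C is M x N and bias has N elements.
void brgemm_ref(float alpha, float beta,
                const std::vector<const float *> &A,
                const std::vector<const float *> &B,
                float *C, const float *bias,
                std::size_t M, std::size_t N, std::size_t K) {
    for (std::size_t m = 0; m < M; ++m)
        for (std::size_t n = 0; n < N; ++n) {
            float acc = 0.f;
            for (std::size_t i = 0; i < A.size(); ++i)  // batch reduction
                for (std::size_t k = 0; k < K; ++k)
                    acc += A[i][m * K + k] * B[i][k * N + n];
            C[m * N + n] = beta * C[m * N + n] + alpha * acc + bias[n];
        }
}
```

An optimized ukernel would perform the same reduction over the batch, but with packed layouts and vector/matrix instructions.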
This proposal discusses exposing these sequential, basic building
for this block level API, there are three considerations:

### Arbitrary strides between A and B blocks

To make the API flexible, we want to allow arbitrary strides between A
and B blocks as shown in the picture:

![](brgemm_pic.png)

This is necessary on multiple occasions, a few examples being:
in block level APIs. However, `dnnl::set_max_cpu_isa` will still be
effective, and can be used by the user to control the maximum ISA
level used by block level APIs.

If more granularity is needed by the end user, we can add CPU ISA selection
as an attribute later.

### Handling of architectural state
There are two options here:

The recommendation is to go with the first option, with explicit set
and release functions, that the user can hoist as they see fit.
We will also have a `brgemm::reset_hw_context()` to avoid sequences
where `release_hw_context()` and `set_hw_context()` are called back-to-back.
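As a usage illustration, hoisting could look like the following sketch. This is illustrative pseudocode against the proposed interface only; the `kernels` container and the omitted execute arguments are hypothetical:

```c++
// Illustrative only: hoist HW context management out of the hot loop.
void run(std::vector<dnnl::ukernel::brgemm> &kernels) {
    kernels.front().set_hw_context();  // e.g. configures the AMX palette once
    for (auto &k : kernels) {
        k.reset_hw_context();          // reconfigure without an intervening release
        // k.execute(...);             // arguments omitted
    }
    // Release is static: done once, independent of any particular object.
    dnnl::ukernel::brgemm::release_hw_context();
}
```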
With respect to OS syscalls, we recommend making them transparent to the
user, by issuing them upon creation of the first brgemm object that would need
here are a few ways to mitigate the lack of kernel-level cache:

## Transforms and transpose

Transformation routines, for example to pack data in a certain layout,
are typically hard to implement without using intrinsics or assembly.
To facilitate packing, we will expose an out-of-place transform
functionality.
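To illustrate the kind of reordering such a transform performs, here is a naive out-of-place packing routine that re-lays out a row-major K x N matrix into `[K/blk][N][blk]` blocks, a common BRGEMM B layout. This is only a sketch of the concept; the actual `packed_32` layout used by the library may differ:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Reorder a row-major K x N matrix so that rows are grouped in blocks
// of `blk` along K, giving a [K/blk][N][blk] destination layout.
// Assumes K is a multiple of blk for simplicity.
std::vector<float> pack_blocked(const std::vector<float> &src,
                                std::size_t K, std::size_t N,
                                std::size_t blk) {
    std::vector<float> dst(K * N);
    for (std::size_t k = 0; k < K; ++k)
        for (std::size_t n = 0; n < N; ++n) {
            std::size_t kb = k / blk, kr = k % blk;
            dst[(kb * N + n) * blk + kr] = src[k * N + n];
        }
    return dst;
}
```

The inner `blk` stride is what lets a kernel load `blk` consecutive K-elements of one column with a single contiguous access, which is why such routines are usually written in intrinsics or assembly.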
## All-in-all

### Interface proposal (C++)
```c++
// namespace name to be defined, leaving it general enough for additional block level APIs
namespace dnnl {
namespace ukernel {

enum packing_tag {
    packed_32,
    plain,
    transposed
};

struct brgemm_attr {
    brgemm_attr();
    void set_scales(int a_scale_mask, int b_scale_mask, int c_scale_mask);
    void set_postops(post_ops &po);
};

struct attr_params {
    attr_params();
    void set_scales(void *a_scale, void *b_scale, void *c_scale);
    // array of pointers for post-op arguments, in the order in which
    // they were appended in attributes
    void set_po_args(void **po_args);
};

struct brgemm {
    struct desc {
        // Vanilla version of brgemm with no post-op or destination conversion.
        desc(dim_t batch, dim_t M, dim_t N, dim_t K,
                data_type dtA, dim_t ldA,
                data_type dtB, dim_t ldB,
                data_type dtC, dim_t ldC,
                float alpha, float beta);

        // Advanced version with post-ops and datatype conversion when D type
        // is different than C type.
        desc(dim_t batch, dim_t M, dim_t N, dim_t K,
                data_type dtA, dim_t ldA,
                data_type dtB, dim_t ldB,
                data_type dtC, dim_t ldC,
                data_type dtD, dim_t ldD,
                float alpha, float beta,
                const brgemm_attr &attr);

        // Queries for expected layouts and temporary memory
        packing_tag get_A_tag() const;
        packing_tag get_B_tag() const;
        size_t get_scratchpad_size() const;
    };

    brgemm(const desc &bd);

    // Advanced version with post-ops and datatype conversion when D type
    // is different than C type.
    brgemm(dim_t batch, dim_t M, dim_t N, dim_t K,
            data_type dtA, dim_t ldA,
            data_type dtB, dim_t ldB,
            data_type dtC, dim_t ldC,
            data_type dtD, dim_t ldD,
            float alpha, float beta,
            const brgemm_attr &attr);

    // Vanilla version of brgemm with no post-op or destination conversion.
    brgemm(dim_t batch, dim_t M, dim_t N, dim_t K,
            data_type dtA, dim_t ldA,
            data_type dtB, dim_t ldB,
            data_type dtC, dim_t ldC,
            float alpha, float beta);

    // Queries for expected layouts and temporary memory
    packing_tag get_A_tag() const; // Not really needed, just for consistency
    packing_tag get_B_tag() const;
    size_t get_scratchpad_size() const;

    // HW context handling.
    // This currently mimics AMX (need to clarify for SME):
    // - Release is static
    // - No release necessary between different calls to set_hw_context
    void set_hw_context() const;
    void reset_hw_context() const;
    static void release_hw_context();

    // Separate kernel generation to allow queries without jit overhead.
    void generate();

    // Execution function for the vanilla brgemm variant.
    // We take pointers to A and B as a vector of pairs, which guarantees
    // they are the same size; the batch size is the size of the vector.
    // Pointers are void*, datatypes are specified in the constructor.
    // Computes C = \beta C + \alpha \sum_i A_i \cdot B_i
    void execute(const std::vector<std::pair<void *, void *>> &A_B,
            void *C, void *scratch = nullptr);

    // Execution function for the advanced brgemm variant.
    // Here the C matrix is just an input to accumulation; the final
    // result after post-ops/conversion will be in D.
    // Computes D = \beta C + \alpha \sum_i A_i \cdot B_i + bias
    void execute(const std::vector<std::pair<void *, void *>> &A_B,
            const void *C, void *D, void *scratch,
            const attr_params &attr_args);
};

struct transform {
    struct desc {
        desc(dim_t M, dim_t N, data_type dt,
                packing_tag tag_src, dim_t ld_src,
                packing_tag tag_dst, dim_t ld_dst);
    };
    transform(const desc &td);
    // both src and dst share the same dt, no fused conversion
    transform(dim_t M, dim_t N, data_type dt,
            packing_tag tag_src, dim_t ld_src,
            packing_tag tag_dst, dim_t ld_dst);
    void generate();
    void execute(const void *src, void *dst);
};

} // namespace ukernel
} // namespace dnnl
```

Two review questions were raised on this interface: given dtA and dtB both `dt::u8`, what dtC and dtD (if with post-ops) are supported for now? And does D have a different dtype from C? On the latter: yes, D is used for post-op and conversion fusion, hence why it has a different datatype.
### Interface proposal (C)

```c
// opaque structure
struct dnnl_ukernel_brgemm_attr_t;
dnnl_status_t dnnl_ukernel_brgemm_attr_create(dnnl_ukernel_brgemm_attr_t *attr);
dnnl_status_t dnnl_ukernel_brgemm_attr_set_scales(dnnl_ukernel_brgemm_attr_t attr,
        int a_scale_mask, int b_scale_mask, int c_scale_mask);
dnnl_status_t dnnl_ukernel_brgemm_attr_set_postops(dnnl_ukernel_brgemm_attr_t attr,
        const_dnnl_post_ops_t po);

// opaque structure
struct dnnl_ukernel_attr_params_t;
dnnl_status_t dnnl_ukernel_attr_params_create(dnnl_ukernel_attr_params_t *ap);
dnnl_status_t dnnl_ukernel_attr_params_set_scales(dnnl_ukernel_attr_params_t ap,
        void *a_scale, void *b_scale, void *c_scale);
dnnl_status_t dnnl_ukernel_attr_params_set_po_args(dnnl_ukernel_attr_params_t ap,
        void **po_args);

// Single creation function; if no post-ops/conversion, D should be set to undef
dnnl_status_t dnnl_ukernel_brgemm_create(dnnl_brgemm_t *brgemm,
        dnnl_dim_t batch, dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K,
        dnnl_data_type_t a_dt, dnnl_dim_t lda,
        dnnl_data_type_t b_dt, dnnl_dim_t ldb,
        dnnl_data_type_t c_dt, dnnl_dim_t ldc,
        dnnl_data_type_t d_dt, dnnl_dim_t ldd,
        float alpha, float beta,
        const_dnnl_ukernel_brgemm_attr_t attr);

// Queries for expected layouts and temporary memory
dnnl_ukernel_packing_tag dnnl_ukernel_brgemm_get_A_tag(const_dnnl_brgemm_t brg);
dnnl_ukernel_packing_tag dnnl_ukernel_brgemm_get_B_tag(const_dnnl_brgemm_t brg);
dnnl_status_t dnnl_ukernel_brgemm_get_scratchpad_size(const_dnnl_brgemm_t brg, size_t *size);

// HW context management. Release is independent of the brgemm object.
dnnl_status_t dnnl_ukernel_brgemm_set_hw_context(const_dnnl_brgemm_t brg);
dnnl_status_t dnnl_ukernel_brgemm_reset_hw_context(const_dnnl_brgemm_t brg);
dnnl_status_t dnnl_ukernel_brgemm_release_hw_context(void);

// separate kernel generation to allow queries without jit overhead
dnnl_status_t dnnl_ukernel_brgemm_generate(dnnl_brgemm_t brg);

// Execution function. Single function here.
// If no post-ops/conversion, C_ptr = D_ptr and attr_arguments = NULL.
dnnl_status_t dnnl_ukernel_brgemm_execute(const_dnnl_brgemm_t brg,
        const void **A_B_ptr, void *C_ptr, void *D_ptr, void *scratchpad_ptr,
        const_dnnl_ukernel_attr_params_t attr_arguments);
```
### Simple example for bf16 matmul with f32 accumulation and relu fusion

```c++
#include "dnnl_block.h"

int matmul_with_relu(const void *src, const void *weights, void *dst) {
    // ... (function body elided in this diff view)
}
```
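Under the proposed C++ interface, the elided body might look roughly like the sketch below. This is illustrative only: the `post_ops`/`append_eltwise` usage, the shape constants, and the leading dimensions are assumptions, and the code is not runnable against any released library:

```c++
// Illustrative sketch against the proposed API; M, N, K are assumed shapes.
int matmul_with_relu(const void *src, const void *weights, void *dst) {
    using namespace dnnl::ukernel;
    const dim_t M = 64, N = 64, K = 64;

    // Hypothetical relu post-op attached through brgemm attributes.
    post_ops po;
    po.append_eltwise(algorithm::eltwise_relu, /*alpha=*/0.f, /*beta=*/0.f);
    brgemm_attr attr;
    attr.set_postops(po);

    // bf16 inputs, f32 accumulation in C, bf16 result in D (= dst).
    brgemm::desc bd(/*batch=*/1, M, N, K,
            data_type::bf16, /*ldA=*/K,
            data_type::bf16, /*ldB=*/N,
            data_type::f32, /*ldC=*/N,
            data_type::bf16, /*ldD=*/N,
            /*alpha=*/1.f, /*beta=*/0.f, attr);
    brgemm kernel(bd);
    kernel.generate();

    std::vector<float> acc(M * N);                       // f32 C buffer
    std::vector<char> scratch(bd.get_scratchpad_size()); // temporary memory

    kernel.set_hw_context();
    kernel.execute({{const_cast<void *>(src), const_cast<void *>(weights)}},
            acc.data(), dst, scratch.data(), attr_params());
    brgemm::release_hw_context();
    return 0;
}
```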
WIP

[^5]: [Intel optimization guide](https://www.intel.com/content/www/us/en/content-details/671488/intel-64-and-ia-32-architectures-optimization-reference-manual-volume-1.html)
[^6]: [The indirect convolution algorithm](https://arxiv.org/abs/1907.02129)
[^7]: [XNNpack microkernels](https://github.com/google/XNNPACK/blob/8b30931dba3d4f23f0da035fa5330a45b5ade5bf/doc/microkernel-naming-conventions.md?plain=1#L57)
[^8]: [FBGEMM microkernels](https://arxiv.org/abs/2101.05615)
Review discussion on HW context handling:

- How about the AMX palette buffer management? Is it possible to expose this to the user, so that users can implement their own `brgemm::set_hw_context()` and `brgemm::release_hw_context()` to avoid those function calls?
- What's the problem with hiding this inside oneDNN?
- It's related to an optimization opportunity. Based on the current design, you must reconfigure the hw context if multiple brgemm kernels are used. But if we can manage the AMX palette, we don't need to reconfigure the hw context if those brgemms use the same AMX palette.
- @ZhennanQin, we would like the API to be as ISA agnostic as possible. Internally, we will not reconfigure if the same palette is already loaded on the core.