[FEATURE] Add query_keys transformer version without split #21115
Hey @agrabows, thanks for submitting the PR.
CI supported jobs: [unix-cpu, unix-gpu, website, windows-cpu, centos-gpu, edge, sanity, centos-cpu, windows-gpu, miscellaneous, clang]
5c11296 to 2495ae9
LGTM.
  input_names.emplace_back("min_q");
  input_names.emplace_back("max_q");
}
input_names.emplace_back("keys");
The "keys" input name should come before "min_q" and "max_q".
done
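A minimal sketch of the ordering requested above, not the PR's actual code: tensor inputs are listed first and the calibration ranges follow. The helper name `BuildInputNames`, the `quantized` flag, and the names "queries", "min_k" and "max_k" are assumptions made for illustration; only "keys", "min_q" and "max_q" appear in the reviewed snippet.

```cpp
#include <string>
#include <vector>

// Hypothetical helper (not from the PR): lists tensor inputs first,
// then the min/max calibration inputs when the operator is quantized.
std::vector<std::string> BuildInputNames(bool quantized) {
  std::vector<std::string> input_names;
  input_names.emplace_back("queries");  // assumed name
  input_names.emplace_back("keys");     // must precede the min/max entries
  if (quantized) {
    input_names.emplace_back("min_q");
    input_names.emplace_back("max_q");
    input_names.emplace_back("min_k");  // assumed name
    input_names.emplace_back("max_k");  // assumed name
  }
  return input_names;
}
```

With this layout, a consumer that binds inputs by position sees the same leading tensor inputs whether or not the quantized ranges are present.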
  float* output_min_0 = outputs[2].data().dptr<float>();
  float* output_max_0 = outputs[3].data().dptr<float>();
  float* output_min_1 = outputs[4].data().dptr<float>();
  float* output_max_1 = outputs[5].data().dptr<float>();
  *output_min_0 = min_output_;
  *output_max_0 = max_output_;
  *output_min_1 = min_output_;
  *output_max_1 = max_output_;
}
There are only 3 output tensors. Doesn't this throw an exception?
It didn't, but you are right: this code was incorrect.
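For context on why no exception surfaced: unchecked indexing past the end of a `std::vector` is undefined behavior in C++, so it is not guaranteed to throw. The sketch below shows a bounds-checked way to write the min/max pair when there are three outputs. It uses plain `std::vector<float>` buffers as a stand-in for the operator's output arrays, and both the helper name `WriteOutputRange` and the index layout (0 = result, 1 = min, 2 = max) are assumptions, not the code that was merged.

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for the operator outputs: index 0 holds the result,
// indices 1 and 2 hold single-element min/max buffers (assumed layout).
void WriteOutputRange(std::vector<std::vector<float>>* outputs,
                      float min_output, float max_output) {
  constexpr std::size_t kMinIdx = 1;
  constexpr std::size_t kMaxIdx = 2;
  // at() throws std::out_of_range on a bad index instead of silently
  // touching memory outside the container.
  outputs->at(kMinIdx).at(0) = min_output;
  outputs->at(kMaxIdx).at(0) = max_output;
}
```

Calling this with fewer than three outputs fails loudly, which is the behavior the reviewer expected from the original snippet.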
@@ -29,54 +29,78 @@
#include "operator/subgraph/common.h"
#include "dnnl_transformer-inl.h"

// 3 tensors within one (queries key values) =
// 3 tensors within one (queries key values)
- // 3 tensors within one (queries key values)
+ // 3 tensors within one (queries keys values)
done
  *status = kFirstSwapAx;
  matched_list->push_back(&input_node);
}
return true;
Shouldn't this `return true` be inside the `if` block, right after `matched_list->push_back(&input_node);`?
done
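The change discussed above can be sketched roughly as follows: the selector reports success only on the path where the node was actually matched and recorded. The `Node` struct, the `SelectStatus` enum values other than `kFirstSwapAx`, the function name `SelectInput`, and the fall-through `return false` are all assumptions for illustration; the real selector has more states and conditions.

```cpp
#include <vector>

// Hypothetical, simplified types; only kFirstSwapAx appears in the PR snippet.
enum SelectStatus { kStart, kFirstSwapAx };

struct Node {
  bool is_swapaxis;
};

// Returns true only when input_node was matched and added to matched_list.
bool SelectInput(const Node& input_node, SelectStatus* status,
                 std::vector<const Node*>* matched_list) {
  if (*status == kStart && input_node.is_swapaxis) {
    *status = kFirstSwapAx;
    matched_list->push_back(&input_node);
    return true;  // moved inside the if, next to push_back
  }
  return false;  // assumed fall-through: unmatched nodes are rejected
}
```

Keeping the `return true` next to `push_back` means the return value and the contents of `matched_list` cannot disagree.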
@mxnet-bot run ci [unix-gpu]
Jenkins CI successfully triggered : [unix-gpu]
Description
MXNet fuses the split, reshape, swapaxis and batch_dot operators for performance. In the GPT-2 model the same fusion can be applied if split is excluded, because the queries and keys already arrive as separate tensors.
![image](https://user-images.githubusercontent.com/59651240/183493340-589cdb0a-bb01-4fd8-8bbd-b9c295de3e77.png)
![image](https://user-images.githubusercontent.com/59651240/183493199-e1c9f506-53b4-4ee7-97f2-89183a40233b.png)
->
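To make the fused pattern concrete, here is a rough reference sketch of what reshape, swapaxis and batch_dot compute together when queries and keys are already separate tensors. It is a naive loop, not the oneDNN-backed kernel from the PR, and it makes several assumptions: a single batch element, row-major inputs of shape (seq, heads * dim), and the hypothetical function name `QueryKeyScores`.

```cpp
#include <cstddef>
#include <vector>

// Reference computation (assumed layout): q and k are row-major
// (seq, heads * dim). The result is row-major (heads, seq, seq) with
// scores[h][i][j] = sum over d of q[i][h][d] * k[j][h][d].
std::vector<float> QueryKeyScores(const std::vector<float>& q,
                                  const std::vector<float>& k,
                                  std::size_t seq, std::size_t heads,
                                  std::size_t dim) {
  std::vector<float> scores(heads * seq * seq, 0.0f);
  for (std::size_t h = 0; h < heads; ++h) {
    for (std::size_t i = 0; i < seq; ++i) {
      for (std::size_t j = 0; j < seq; ++j) {
        float acc = 0.0f;
        for (std::size_t d = 0; d < dim; ++d) {
          // Indexing by (row, head, d) replaces the explicit
          // reshape + swapaxis steps of the unfused graph.
          acc += q[(i * heads + h) * dim + d] * k[(j * heads + h) * dim + d];
        }
        scores[(h * seq + i) * seq + j] = acc;
      }
    }
  }
  return scores;
}
```

The point of the sketch is that the head dimension can be addressed directly by index arithmetic, so no intermediate reshaped or transposed copies are needed, which is the kind of saving a fused operator aims for.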
Checklist
Essentials