Add support for different src and dst datatypes in the SYCL implementation for pooling #1878
tests/benchdnn/pool/ref_pool.cpp
Outdated
@@ -38,6 +38,9 @@ void compute_ref_fwd(const prb_t *prb, const args_t &args) {
    // XXX: this is a hack to let tests with padded area to pass for bf16
    // dt due to the library initialize values with -max_dt, but not -INF.
    float max_value = lowest_dt(prb->dst_dt());
    if (prb->src_dt() == dnnl_u8 || prb->src_dt() == dnnl_s8
            || prb->src_dt() == dnnl_f16 || prb->src_dt() == dnnl_f32)
        max_value = lowest_dt(prb->src_dt());
Could you elaborate, please, why is this needed?
Thanks for the comment. Line 40 is a previous hack that was there to deal with bf16. Because of that hack, enabling different src/dst datatypes failed: max_value took into account the dst datatype instead of the src datatype. When src and dst use different datatypes, there are cases where max_value takes negative values even when src_dt=u8 (just to mention one example). Checking src_dt instead of dst_dt lets us be sure that the max_value variable is initialized correctly.
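The effect described above can be sketched with a small stand-alone example. This is only an illustration under stated assumptions, not benchdnn or library code: `max_pool_1d` and `check_padded_window` are hypothetical names, the pooling is 1D with stride 1, and padded taps are simply skipped. With the shape from the u8:s8 reproducer further down in this thread (iw13, ow14, kw3, pw3), the first output window lies entirely in the padding, so the output there is whatever the accumulator was initialized with.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical helper, not benchdnn code: 1D max pooling with stride 1.
// Taps that fall outside [0, iw) are padding and are skipped, so a window
// lying entirely in the padding returns `init` unchanged.
std::vector<float> max_pool_1d(const std::vector<float> &src, int ow, int kw,
        int pw, float init) {
    assert(ow > 0 && kw > 0 && pw >= 0);
    const int iw = static_cast<int>(src.size());
    std::vector<float> dst(ow, init);
    for (int o = 0; o < ow; ++o) {
        for (int k = 0; k < kw; ++k) {
            const int i = o - pw + k;
            if (i < 0 || i >= iw) continue; // padded tap
            dst[o] = std::max(dst[o], src[i]);
        }
    }
    return dst;
}

// Usage check with the shape from the u8:s8 reproducer (iw13 ow14 kw3 pw3).
void check_padded_window() {
    const std::vector<float> src(13, 5.f); // arbitrary u8-range data
    const auto by_src = max_pool_1d(src, 14, 3, 3, 0.f); // lowest u8
    const auto by_dst = max_pool_1d(src, 14, 3, 3, -128.f); // lowest s8
    assert(by_src[0] == 0.f && by_dst[0] == -128.f); // differ only in padding
    assert(by_src[1] == 5.f && by_dst[1] == 5.f); // agree once real data is hit
}
```

The two initializations disagree only where a window contains no valid input, which would be consistent with the failing points in the reproducers being a small, regularly spaced subset of the output.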
Different data types are supported for other backends, and so far everything has worked apart from that issue.
Please post the benchdnn commands and output, without this patch to benchdnn, where it fails. 2-3 would be enough. Thanks.
The output I got when enabling the new src/dst combinations is:
./tests/benchdnn/benchdnn --pool --engine=gpu --dir=FWD_I --dt=u8:s8 --tag=axb ic35iw13ow14kw3pw3
[ 0][DST][0:0:0] exp_f32: -128 exp: -128 got: 0 diff: 128 rdiff: 1
[ 14][DST][0:1:0] exp_f32: -128 exp: -128 got: 0 diff: 128 rdiff: 1
[ 28][DST][0:2:0] exp_f32: -128 exp: -128 got: 0 diff: 128 rdiff: 1
[ 42][DST][0:3:0] exp_f32: -128 exp: -128 got: 0 diff: 128 rdiff: 1
[ 56][DST][0:4:0] exp_f32: -128 exp: -128 got: 0 diff: 128 rdiff: 1
[ 70][DST][0:5:0] exp_f32: -128 exp: -128 got: 0 diff: 128 rdiff: 1
[ 84][DST][0:6:0] exp_f32: -128 exp: -128 got: 0 diff: 128 rdiff: 1
[ 98][DST][0:7:0] exp_f32: -128 exp: -128 got: 0 diff: 128 rdiff: 1
[ 112][DST][0:8:0] exp_f32: -128 exp: -128 got: 0 diff: 128 rdiff: 1
[ 126][DST][0:9:0] exp_f32: -128 exp: -128 got: 0 diff: 128 rdiff: 1
[COMPARE_STATS][DST]: trh=0 err_max_diff: 128 err_max_rdiff: 1 all_max_diff: 128 all_max_rdiff: 1
0:FAILED (errors:70 total:980) __REPRO: --pool --engine=gpu --dir=FWD_I --dt=u8:s8 --tag=axb ic35iw13ow14kw3pw3
tests:1 passed:0 skipped:0 mistrusted:0 unimplemented:0 invalid_arguments:0 failed:1 listed:0
total: 1.63s; fill: 0.44s (27%); compute_ref: 0.00s (0%); compare: 0.02s (1%);
TBB Warning: Leaked 2 observer_proxy objects
./tests/benchdnn/benchdnn --pool --engine=gpu --dir=FWD_I --dt=f16:u8 --tag=axb --attr-post-ops=add:f32:per_oc ic32ih32iw13oh17ow14kh4kw6sh2sw1ph4pw6
[ 0][DST][0:0:0:0] exp_f32: 13.5 exp: 14 got: 0 diff: 14 rdiff: 1
[ 1][DST][0:0:0:1] exp_f32: 13.5 exp: 14 got: 0 diff: 14 rdiff: 1
[ 2][DST][0:0:0:2] exp_f32: 13.5 exp: 14 got: 0 diff: 14 rdiff: 1
[ 3][DST][0:0:0:3] exp_f32: 13.5 exp: 14 got: 0 diff: 14 rdiff: 1
[ 4][DST][0:0:0:4] exp_f32: 13.5 exp: 14 got: 0 diff: 14 rdiff: 1
[ 5][DST][0:0:0:5] exp_f32: 13.5 exp: 14 got: 0 diff: 14 rdiff: 1
[ 6][DST][0:0:0:6] exp_f32: 13.5 exp: 14 got: 0 diff: 14 rdiff: 1
[ 7][DST][0:0:0:7] exp_f32: 13.5 exp: 14 got: 0 diff: 14 rdiff: 1
[ 8][DST][0:0:0:8] exp_f32: 13.5 exp: 14 got: 0 diff: 14 rdiff: 1
[ 9][DST][0:0:0:9] exp_f32: 13.5 exp: 14 got: 0 diff: 14 rdiff: 1
[COMPARE_STATS][DST]: trh=0 err_max_diff: 16 err_max_rdiff: 1 all_max_diff: 16 all_max_rdiff: 1
0:FAILED (errors:1998 total:15232) __REPRO: --pool --engine=gpu --dir=FWD_I --dt=f16:u8 --tag=axb --attr-post-ops=add:f32:per_oc ic32ih32iw13oh17ow14kh4kw6sh2sw1ph4pw6
tests:1 passed:0 skipped:0 mistrusted:0 unimplemented:0 invalid_arguments:0 failed:1 listed:0
total: 1.35s; fill: 0.17s (12%); compute_ref: 0.00s (0%); compare: 0.03s (2%);
TBB Warning: Leaked 2 observer_proxy objects
./tests/benchdnn/benchdnn --pool --engine=gpu --dir=FWD_I --dt=s8:f32 --tag=axb ic35iw13ow14kw3pw3
[ 0][DST][0:0:0] exp_f32:-3.40282e+38 exp:-3.40282e+38 got: -128 diff:3.40282e+38 rdiff: 1
[ 14][DST][0:1:0] exp_f32:-3.40282e+38 exp:-3.40282e+38 got: -128 diff:3.40282e+38 rdiff: 1
[ 28][DST][0:2:0] exp_f32:-3.40282e+38 exp:-3.40282e+38 got: -128 diff:3.40282e+38 rdiff: 1
[ 42][DST][0:3:0] exp_f32:-3.40282e+38 exp:-3.40282e+38 got: -128 diff:3.40282e+38 rdiff: 1
[ 56][DST][0:4:0] exp_f32:-3.40282e+38 exp:-3.40282e+38 got: -128 diff:3.40282e+38 rdiff: 1
[ 70][DST][0:5:0] exp_f32:-3.40282e+38 exp:-3.40282e+38 got: -128 diff:3.40282e+38 rdiff: 1
[ 84][DST][0:6:0] exp_f32:-3.40282e+38 exp:-3.40282e+38 got: -128 diff:3.40282e+38 rdiff: 1
[ 98][DST][0:7:0] exp_f32:-3.40282e+38 exp:-3.40282e+38 got: -128 diff:3.40282e+38 rdiff: 1
[ 112][DST][0:8:0] exp_f32:-3.40282e+38 exp:-3.40282e+38 got: -128 diff:3.40282e+38 rdiff: 1
[ 126][DST][0:9:0] exp_f32:-3.40282e+38 exp:-3.40282e+38 got: -128 diff:3.40282e+38 rdiff: 1
[COMPARE_STATS][DST]: trh=0 err_max_diff:3.40282e+38 err_max_rdiff: 1 all_max_diff:3.40282e+38 all_max_rdiff: 1
0:FAILED (errors:70 total:980) __REPRO: --pool --engine=gpu --dir=FWD_I --dt=s8:f32 --tag=axb ic35iw13ow14kw3pw3
tests:1 passed:0 skipped:0 mistrusted:0 unimplemented:0 invalid_arguments:0 failed:1 listed:0
total: 1.25s; fill: 0.06s (5%); compute_ref: 0.00s (0%); compare: 0.08s (7%);
TBB Warning: Leaked 2 observer_proxy objects
I think the values initialized in the tests/benchdnn/pool/ref_pool.cpp source file were breaking the tests. Let me know what you think. Thanks for your help.
@kala855, thanks for these examples. There seems to be no clear way to proceed with it, and it will require clarifying with users what the expected behavior is here. Two situations are possible:
- Since no quantized operations are defined in framework graphs, all operations are considered f32/bf16. This means data would be converted to f32/bf16, computed, and then converted to the output data type separately. This is aligned with current benchdnn expectations.
- The quantized operation uses src_dt as the data type to update values for max pooling, then converts the result to f32 to apply post-ops, if any, and then converts to dst_dt. This is aligned with the new change and breaks the expectations of the previous option.
The request here is to postpone this PR until we get clarification from the frameworks and then proceed with the most reasonable option. Thanks.
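The two options can be compared on a single pooling window. The sketch below is one assumed reading of them for the s8:f32 case, not code from the library or from benchdnn; `max_via_f32`, `max_via_src_dt` and `check_options` are hypothetical names, and the window is modeled as the list of valid (non-padded) taps. Post-ops are left out.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <limits>
#include <vector>

// Option 1 (assumed reading): promote each s8 value to f32 and reduce in
// f32, starting from the lowest finite f32 value.
float max_via_f32(const std::vector<std::int8_t> &window) {
    float acc = std::numeric_limits<float>::lowest();
    for (std::int8_t v : window)
        acc = std::max(acc, static_cast<float>(v));
    return acc;
}

// Option 2 (assumed reading): reduce in the source type, starting from its
// lowest value, then convert to f32 (where post-ops would be applied).
float max_via_src_dt(const std::vector<std::int8_t> &window) {
    std::int8_t acc = std::numeric_limits<std::int8_t>::lowest();
    for (std::int8_t v : window)
        acc = std::max(acc, v);
    return static_cast<float>(acc);
}

void check_options() {
    const std::vector<std::int8_t> data {-5, 7, 3};
    assert(max_via_f32(data) == 7.f && max_via_src_dt(data) == 7.f);

    const std::vector<std::int8_t> padded_only; // window with no valid taps
    assert(max_via_f32(padded_only) == std::numeric_limits<float>::lowest());
    assert(max_via_src_dt(padded_only) == -128.f);
}
```

Under these assumptions the options agree whenever the window holds at least one real value and differ only for fully padded windows, giving roughly -3.40282e+38 versus -128, the same pair of values that appears as exp and got in the s8:f32 reproducer above.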
I'd expect benchdnn change to be dropped from the final version with limited int8 support...
@dzarukin I did take into account the suggestions made. However, even with limited int8 support, this part of the code still needs to be updated. Because we accept different src/dst datatypes, we need to initialize the max pooling algorithm differently for Nvidia GPU. I included a conditional here to initialize the data using prb->src_dt().
src/gpu/sycl/ref_pooling.hpp
Outdated
&& (!utils::one_of(f64, src_md(0)->data_type)
        && !utils::one_of(f64, dst_md(0)->data_type))
&& (IMPLICATION(utils::one_of(bf16, src_md(0)->data_type),
        utils::one_of(bf16, dst_md(0)->data_type)))
Could you explain, please, what makes bf16 so special?
In this instance, we're adhering to the specifications outlined in the table available here. Notably, the bf16 scenario stands out as one case requiring identical source and destination data types. Thanks for the comment. Let me know if something else needs to be clarified.
I don't see a strong reason to limit the support if it can work as expected.
@dzarukin We took into account what was discussed during the meeting. The PR now avoids applying pooling src:dst (u8:u8, s8:s8).
Could you clarify what issue this fixes?
src/gpu/sycl/ref_pooling.hpp
Outdated
&& f64 != src_md(0)->data_type
&& f64 != dst_md(0)->data_type
&& (IMPLICATION(bf16 == src_md(0)->data_type,
        bf16 == dst_md(0)->data_type))
This is redundant with lines 56-57.
It was fixed in the current commit. Thanks.
src/gpu/sycl/ref_pooling.hpp
Outdated
        dst_md(0)->data_type != u8))
&& (IMPLICATION(
        src_md(0)->data_type != dst_md(0)->data_type,
        desc_.prop_kind == forward_inference))
Not sure I understand the logic here. You are disabling src and dst datatype to be the same except for bf16.
Having the same datatype for src and dst should be valid.
According to this table, src/dst datatypes can differ only when forward_inference is used. The bf16 dt is not combined with any other dt. Maybe I am misunderstanding something in the table. Let me know if this clarifies the questions a bit. Thanks.
c1c7999 to d143eec
@mgouicem We received feedback mentioning that we need to avoid the use of
u8:s8 and s8:u8 have semantic problems.
@dzarukin just to be sure ... The ones we will need to avoid are u8:s8 and s8:u8 ... However, u8:u8 and s8:s8 are allowed?
Yes, those are allowed and should be supported. There's no ambiguity in such combinations - same init value is used on both ends.
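The data-type rules discussed in this thread can be collected into one predicate. The sketch below is a hedged illustration only and is not the condition that was merged: the enums `dt` and `prop` and the functions `implication`, `pooling_dt_ok` and `check_dispatch` are hypothetical stand-ins, and `implication(cause, effect)` is assumed to mean the same as the IMPLICATION macro in the diffs, that is, "cause implies effect".

```cpp
#include <cassert>

// Hypothetical stand-ins for the library's enums; for illustration only.
enum class dt { f32, f16, bf16, f64, s8, u8 };
enum class prop { forward_training, forward_inference, backward };

// "cause implies effect": false only when cause holds and effect does not.
constexpr bool implication(bool cause, bool effect) {
    return !cause || effect;
}

// Sketch of the checks discussed in the review, under the assumptions above.
constexpr bool pooling_dt_ok(dt src, dt dst, prop kind) {
    return src != dt::f64 && dst != dt::f64 // no f64 on either side
            && implication(src == dt::bf16, dst == dt::bf16)
            // differing src/dst only for forward inference
            && implication(src != dst, kind == prop::forward_inference)
            // u8:s8 and s8:u8 are ambiguous; u8:u8 and s8:s8 stay allowed
            && !(src == dt::u8 && dst == dt::s8)
            && !(src == dt::s8 && dst == dt::u8);
}

void check_dispatch() {
    assert(pooling_dt_ok(dt::u8, dt::u8, prop::forward_training));
    assert(pooling_dt_ok(dt::s8, dt::s8, prop::forward_inference));
    assert(!pooling_dt_ok(dt::u8, dt::s8, prop::forward_inference));
    assert(!pooling_dt_ok(dt::s8, dt::u8, prop::forward_inference));
    assert(pooling_dt_ok(dt::s8, dt::f32, prop::forward_inference));
    assert(!pooling_dt_ok(dt::s8, dt::f32, prop::forward_training));
}
```

Writing each rule as its own conjunct keeps identical src/dst types valid by default, which addresses the earlier concern that same-type combinations were being disabled.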
d143eec to f017483
The implementation now does what is suggested. Let me know if something else is missing @dzarukin. Thank you.
src/gpu/generic/sycl/ref_pooling.hpp
Outdated
&& (!utils::one_of(f64, src_md(0)->data_type)
        && !utils::one_of(f64, dst_md(0)->data_type))
- && (!utils::one_of(f64, src_md(0)->data_type)
-         && !utils::one_of(f64, dst_md(0)->data_type))
+ && (!utils::one_of(f64, src_md(0)->data_type, dst_md(0)->data_type))
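The suggestion relies on one_of accepting several candidates at once. The sketch below uses a hypothetical stand-in for `utils::one_of` and a hypothetical `dt` enum (neither is copied from the library) to check that the single-call form gives the same answer as the two-call form for every src/dst pair.

```cpp
#include <cassert>

enum class dt { f32, f16, bf16, f64, s8, u8 }; // hypothetical stand-in

// Stand-in for utils::one_of: true when `val` equals any listed item.
template <typename T, typename U>
constexpr bool one_of(T val, U item) {
    return val == item;
}

template <typename T, typename U, typename... Us>
constexpr bool one_of(T val, U item, Us... rest) {
    return val == item || one_of(val, rest...);
}

void check_equivalence() {
    const dt all[] = {dt::f32, dt::f16, dt::bf16, dt::f64, dt::s8, dt::u8};
    for (dt src : all) {
        for (dt dst : all) {
            const bool two_calls
                    = !one_of(dt::f64, src) && !one_of(dt::f64, dst);
            const bool one_call = !one_of(dt::f64, src, dst);
            assert(two_calls == one_call);
        }
    }
}
```

Both forms reject a pair exactly when either side is f64, so the shorter spelling changes readability only, not behavior.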
Change added.
src/gpu/generic/sycl/ref_pooling.hpp
Outdated
        dst_md(0)->data_type != s8))
&& (IMPLICATION(
        src_md(0)->data_type != dst_md(0)->data_type,
        desc_.prop_kind == forward_inference))
- desc_.prop_kind == forward_inference))
+ desc()->prop_kind == forward_inference))
Change added.
src/gpu/generic/sycl/ref_pooling.hpp
Outdated
&& (IMPLICATION(utils::one_of(bf16, src_md(0)->data_type),
        utils::one_of(bf16, dst_md(0)->data_type)))
I'd stick to plain comparison.
All of the changes were included. Let me know if something else is needed. Thank you.
f017483 to 81aa661
81aa661 to 9879e77
Description
This pull request enables different source (src) and destination (dst) datatypes in the forward inference pass of the SYCL pooling operator implementation.