Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for different src and dst datatypes in the SYCL implementation for pooling #1878

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

kala855
Copy link
Contributor

@kala855 kala855 commented Apr 25, 2024

Description

This pull request facilitates the utilization of distinct source (src) and destination (dst) datatypes during the forward inference pass within the SYCL pooling operator implementation.

@@ -38,6 +38,9 @@ void compute_ref_fwd(const prb_t *prb, const args_t &args) {
// XXX: this is a hack to let tests with padded area to pass for bf16
// dt due to the library initialize values with -max_dt, but not -INF.
float max_value = lowest_dt(prb->dst_dt());
if (prb->src_dt() == dnnl_u8 || prb->src_dt() == dnnl_s8
|| prb->src_dt() == dnnl_f16 || prb->src_dt() == dnnl_f32)
max_value = lowest_dt(prb->src_dt());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you elaborate, please, why is this needed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the comment. Line 40 is a previous hack that was there to deal with bf16. Because of that hack enabling different src/dst datatypes failed, the max_value was taking into account the dst datatype instead of the src. When src or dst uses different datatypes there are cases where max_value takes negative values even when src_dt=u8 (just to mention an example). Checking the src_dt instead of dst_dt allows us to be sure that we initialize the max_value variable correctly.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Different data types are supported for other backends and yet so far everything worked beside that issue.
Please post benchdnn command and output without this patch to benchdnn where it fails. 2-3 would be enough. Thanks.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The output I was having when enabling the new src/dst combinations are:

./tests/benchdnn/benchdnn --pool --engine=gpu --dir=FWD_I --dt=u8:s8 --tag=axb ic35iw13ow14kw3pw3
[   0][DST][0:0:0] exp_f32:        -128 exp:        -128 got:           0 diff:     128 rdiff:       1
[  14][DST][0:1:0] exp_f32:        -128 exp:        -128 got:           0 diff:     128 rdiff:       1
[  28][DST][0:2:0] exp_f32:        -128 exp:        -128 got:           0 diff:     128 rdiff:       1
[  42][DST][0:3:0] exp_f32:        -128 exp:        -128 got:           0 diff:     128 rdiff:       1
[  56][DST][0:4:0] exp_f32:        -128 exp:        -128 got:           0 diff:     128 rdiff:       1
[  70][DST][0:5:0] exp_f32:        -128 exp:        -128 got:           0 diff:     128 rdiff:       1
[  84][DST][0:6:0] exp_f32:        -128 exp:        -128 got:           0 diff:     128 rdiff:       1
[  98][DST][0:7:0] exp_f32:        -128 exp:        -128 got:           0 diff:     128 rdiff:       1
[ 112][DST][0:8:0] exp_f32:        -128 exp:        -128 got:           0 diff:     128 rdiff:       1
[ 126][DST][0:9:0] exp_f32:        -128 exp:        -128 got:           0 diff:     128 rdiff:       1
[COMPARE_STATS][DST]: trh=0 err_max_diff:     128 err_max_rdiff:       1 all_max_diff:     128 all_max_rdiff:       1
0:FAILED (errors:70 total:980) __REPRO: --pool --engine=gpu --dir=FWD_I --dt=u8:s8 --tag=axb ic35iw13ow14kw3pw3
tests:1 passed:0 skipped:0 mistrusted:0 unimplemented:0 invalid_arguments:0 failed:1 listed:0
total: 1.63s; fill: 0.44s (27%); compute_ref: 0.00s (0%); compare: 0.02s (1%);
TBB Warning: Leaked 2 observer_proxy objects
./tests/benchdnn/benchdnn --pool --engine=gpu --dir=FWD_I --dt=f16:u8 --tag=axb --attr-post-ops=add:f32:per_oc ic32ih32iw13oh17ow14kh4kw6sh2sw1ph4pw6
[   0][DST][0:0:0:0] exp_f32:        13.5 exp:          14 got:           0 diff:      14 rdiff:       1
[   1][DST][0:0:0:1] exp_f32:        13.5 exp:          14 got:           0 diff:      14 rdiff:       1
[   2][DST][0:0:0:2] exp_f32:        13.5 exp:          14 got:           0 diff:      14 rdiff:       1
[   3][DST][0:0:0:3] exp_f32:        13.5 exp:          14 got:           0 diff:      14 rdiff:       1
[   4][DST][0:0:0:4] exp_f32:        13.5 exp:          14 got:           0 diff:      14 rdiff:       1
[   5][DST][0:0:0:5] exp_f32:        13.5 exp:          14 got:           0 diff:      14 rdiff:       1
[   6][DST][0:0:0:6] exp_f32:        13.5 exp:          14 got:           0 diff:      14 rdiff:       1
[   7][DST][0:0:0:7] exp_f32:        13.5 exp:          14 got:           0 diff:      14 rdiff:       1
[   8][DST][0:0:0:8] exp_f32:        13.5 exp:          14 got:           0 diff:      14 rdiff:       1
[   9][DST][0:0:0:9] exp_f32:        13.5 exp:          14 got:           0 diff:      14 rdiff:       1
[COMPARE_STATS][DST]: trh=0 err_max_diff:      16 err_max_rdiff:       1 all_max_diff:      16 all_max_rdiff:       1
0:FAILED (errors:1998 total:15232) __REPRO: --pool --engine=gpu --dir=FWD_I --dt=f16:u8 --tag=axb --attr-post-ops=add:f32:per_oc ic32ih32iw13oh17ow14kh4kw6sh2sw1ph4pw6
tests:1 passed:0 skipped:0 mistrusted:0 unimplemented:0 invalid_arguments:0 failed:1 listed:0
total: 1.35s; fill: 0.17s (12%); compute_ref: 0.00s (0%); compare: 0.03s (2%);
TBB Warning: Leaked 2 observer_proxy objects
./tests/benchdnn/benchdnn --pool --engine=gpu --dir=FWD_I --dt=s8:f32 --tag=axb ic35iw13ow14kw3pw3
[   0][DST][0:0:0] exp_f32:-3.40282e+38 exp:-3.40282e+38 got:        -128 diff:3.40282e+38 rdiff:       1
[  14][DST][0:1:0] exp_f32:-3.40282e+38 exp:-3.40282e+38 got:        -128 diff:3.40282e+38 rdiff:       1
[  28][DST][0:2:0] exp_f32:-3.40282e+38 exp:-3.40282e+38 got:        -128 diff:3.40282e+38 rdiff:       1
[  42][DST][0:3:0] exp_f32:-3.40282e+38 exp:-3.40282e+38 got:        -128 diff:3.40282e+38 rdiff:       1
[  56][DST][0:4:0] exp_f32:-3.40282e+38 exp:-3.40282e+38 got:        -128 diff:3.40282e+38 rdiff:       1
[  70][DST][0:5:0] exp_f32:-3.40282e+38 exp:-3.40282e+38 got:        -128 diff:3.40282e+38 rdiff:       1
[  84][DST][0:6:0] exp_f32:-3.40282e+38 exp:-3.40282e+38 got:        -128 diff:3.40282e+38 rdiff:       1
[  98][DST][0:7:0] exp_f32:-3.40282e+38 exp:-3.40282e+38 got:        -128 diff:3.40282e+38 rdiff:       1
[ 112][DST][0:8:0] exp_f32:-3.40282e+38 exp:-3.40282e+38 got:        -128 diff:3.40282e+38 rdiff:       1
[ 126][DST][0:9:0] exp_f32:-3.40282e+38 exp:-3.40282e+38 got:        -128 diff:3.40282e+38 rdiff:       1
[COMPARE_STATS][DST]: trh=0 err_max_diff:3.40282e+38 err_max_rdiff:       1 all_max_diff:3.40282e+38 all_max_rdiff:       1
0:FAILED (errors:70 total:980) __REPRO: --pool --engine=gpu --dir=FWD_I --dt=s8:f32 --tag=axb ic35iw13ow14kw3pw3
tests:1 passed:0 skipped:0 mistrusted:0 unimplemented:0 invalid_arguments:0 failed:1 listed:0
total: 1.25s; fill: 0.06s (5%); compute_ref: 0.00s (0%); compare: 0.08s (7%);
TBB Warning: Leaked 2 observer_proxy objects

What I think is that the values initialized in the tests/benchdnn/pool/ref_pool.cpp source file were breaking the tests. Let me know what do you think about it. Thanks for your help.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kala855, thanks for these examples. It seems there's no clear option how to proceed with it, and it will require clarification with users on what's the expected behavior here. Two possible situations:

  1. In the sense that there's no quantized operations defined in framework graphs, all operations are considered as f32/bf16. It means that data would be converted to f32/bf16, computed, and then converted to output data type separately. This is aligned with current benchdnn expectations.
  2. The quantized operation uses src_dt as a data type to update values for max pooling. Then converts it into f32 to apply post-ops, if any, and then converts to dst_dt. This is aligned with the new change and breaking the expectations of the previous one.

A request here is to postpone this PR until we get clarifications from frameworks and proceed with the most reasonable option. Thanks.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd expect benchdnn change to be dropped from the final version with limited int8 support...

Copy link
Contributor Author

@kala855 kala855 Jul 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dzarukin I did take into account the suggestions made. However, with limited int8 support, this part of the code still needs to be updated. Because we accept different src/dst datatypes we need to initialize the max pooling algorithm differently for Nvidia GPU. I included a conditional here to initialize the data using prb->src_dt().

&& (!utils::one_of(f64, src_md(0)->data_type)
&& !utils::one_of(f64, dst_md(0)->data_type))
&& (IMPLICATION(utils::one_of(bf16, src_md(0)->data_type),
utils::one_of(bf16, dst_md(0)->data_type)))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you explain, please, what makes bf16 so special?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this instance, we're adhering to the specifications outlined in the table available here. Notably, the bf16 scenario stands out as one case requiring identical source and destination data types. Thanks for the comment. Let me know if something else needs to be clarified.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see a strong reason to limit the support if it can work as expected.

@vpirogov vpirogov added this to the v3.6 milestone May 21, 2024
@kala855
Copy link
Contributor Author

kala855 commented Jun 17, 2024

@dzarukin We take into account what was discussed during the meeting. The PR now avoids applying pooling src:dst (u8:u8, s8:s8).

@kala855 kala855 requested a review from dzarukin June 17, 2024 09:46
@mgouicem
Copy link
Contributor

@dzarukin We take into account what was discussed during the meeting. The PR now avoids applying pooling src:dst (u8:u8, s8:s8).

Could you clarify what issue does this fix?
IIUC, the main problem is in the selection of datatype for the max value computation (see this comment). In the case where both src and dst have same datatype, there should be no ambiguity.

&& f64 != src_md(0)->data_type
&& f64 != dst_md(0)->data_type
&& (IMPLICATION(bf16 == src_md(0)->data_type,
bf16 == dst_md(0)->data_type))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is redundant with line 56-57.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was fixed in the current commit. Thanks.

dst_md(0)->data_type != u8))
&& (IMPLICATION(
src_md(0)->data_type != dst_md(0)->data_type,
desc_.prop_kind == forward_inference))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure I understand the logic here. You are disabling src and dst datatype to be the same except for bf16.
Having the same datatype for src and dst should be valid.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this table src/dst datatypes could be different just when forward_inference is used. The bf16 dt is not combined with any other dt. Maybe I am misunderstanding something in the table. Let me know if this clarifies a little bit the questions. Thanks.

@kala855
Copy link
Contributor Author

kala855 commented Jun 26, 2024

@dzarukin We take into account what was discussed during the meeting. The PR now avoids applying pooling src:dst (u8:u8, s8:s8).

Could you clarify what issue does this fix? IIUC, the main problem is in the selection of datatype for the max value computation (see this comment). In the case where both src and dst have same datatype, there should be no ambiguity.

@mgouicem We received feedback mentioning that we need to avoid the use of u8:u8 and s8:s8. Maybe @dzarukin have any additional clue about it. If this combination is accepted we need to enabled it an run some tests. I am in line with your thoughts about the max value computation.

@dzarukin
Copy link
Contributor

dzarukin commented Jul 1, 2024

@mgouicem We received feedback mentioning that we need to avoid the use of u8:u8 and s8:s8. Maybe @dzarukin have any additional clue about it. If this combination is accepted we need to enabled it an run some tests. I am in line with your thoughts about the max value computation.

u8:s8 and s8:u8 have semantic problems.

@kala855
Copy link
Contributor Author

kala855 commented Jul 3, 2024

@mgouicem We received feedback mentioning that we need to avoid the use of u8:u8 and s8:s8. Maybe @dzarukin have any additional clue about it. If this combination is accepted we need to enabled it an run some tests. I am in line with your thoughts about the max value computation.

u8:s8 and s8:u8 have semantic problems.

@dzarukin just to be sure ... The ones we will need to avoid are u8:s8 and s8:u8 ... However, u8:u8 and s8:s8 are allowed?
The previous just to be sure we are talking about the same and proceed to do the corrections. Thank you.

@kala855 kala855 requested a review from mgouicem July 3, 2024 09:34
@dzarukin
Copy link
Contributor

dzarukin commented Jul 3, 2024

@mgouicem We received feedback mentioning that we need to avoid the use of u8:u8 and s8:s8. Maybe @dzarukin have any additional clue about it. If this combination is accepted we need to enabled it an run some tests. I am in line with your thoughts about the max value computation.

u8:s8 and s8:u8 have semantic problems.

@dzarukin just to be sure ... The ones we will need to avoid are u8:s8 and s8:u8 ... However, u8:u8 and s8:s8 are allowed? The previous just to be sure we are talking about the same and proceed to do the corrections. Thank you.

Yes, those are allowed and should be supported. There's no ambiguity in such combinations - same init value is used on both ends.

@kala855
Copy link
Contributor Author

kala855 commented Jul 10, 2024

@mgouicem We received feedback mentioning that we need to avoid the use of u8:u8 and s8:s8. Maybe @dzarukin have any additional clue about it. If this combination is accepted we need to enabled it an run some tests. I am in line with your thoughts about the max value computation.

u8:s8 and s8:u8 have semantic problems.

@dzarukin just to be sure ... The ones we will need to avoid are u8:s8 and s8:u8 ... However, u8:u8 and s8:s8 are allowed? The previous just to be sure we are talking about the same and proceed to do the corrections. Thank you.

Yes, those are allowed and should be supported. There's no ambiguity in such combinations - same init value is used on both ends.

The implementation now does what is suggested. Let me know if something else is missing @dzarukin . Thank you.

@@ -38,6 +38,9 @@ void compute_ref_fwd(const prb_t *prb, const args_t &args) {
// XXX: this is a hack to let tests with padded area to pass for bf16
// dt due to the library initialize values with -max_dt, but not -INF.
float max_value = lowest_dt(prb->dst_dt());
if (prb->src_dt() == dnnl_u8 || prb->src_dt() == dnnl_s8
|| prb->src_dt() == dnnl_f16 || prb->src_dt() == dnnl_f32)
max_value = lowest_dt(prb->src_dt());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd expect benchdnn change to be dropped from the final version with limited int8 support...

Comment on lines 56 to 57
&& (!utils::one_of(f64, src_md(0)->data_type)
&& !utils::one_of(f64, dst_md(0)->data_type))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
&& (!utils::one_of(f64, src_md(0)->data_type)
&& !utils::one_of(f64, dst_md(0)->data_type))
&& (!utils::one_of(f64, src_md(0)->data_type, dst_md(0)->data_type))

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Change added.

dst_md(0)->data_type != s8))
&& (IMPLICATION(
src_md(0)->data_type != dst_md(0)->data_type,
desc_.prop_kind == forward_inference))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
desc_.prop_kind == forward_inference))
desc()->prop_kind == forward_inference))

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Change added.

Comment on lines 58 to 59
&& (IMPLICATION(utils::one_of(bf16, src_md(0)->data_type),
utils::one_of(bf16, dst_md(0)->data_type)))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd stick to plain comparison.

Copy link
Contributor Author

@kala855 kala855 Jul 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All of the changes were included. Let me know if something else is needed. Thank you.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

6 participants