
[Transform] Modify FuseTIR pass to propagate buffer attributes #17075

Merged · 1 commit · Jun 17, 2024

Conversation

quic-sanirudh (Contributor):

Arguments of a fused TIR PrimFunc generated from a fused Relax function do not retain all the buffer attributes of their original PrimFuncs, because the buffers are created from the StructInfo of the Relax vars. This patch collects a mapping from Relax vars to their corresponding TIR buffers in a fused Relax function and uses that information to propagate buffer attributes such as `axis_separators` and `storage_scope`.
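For context, a hedged sketch of the idea (the helper name is illustrative, not the actual patch): when the fused PrimFunc's parameter buffer is re-declared from the Relax StructInfo, the storage scope and axis separators of the original PrimFunc's buffer can be carried over.

```cpp
#include <tvm/tir/buffer.h>

using namespace tvm;

// Illustrative helper (not the committed code): re-declare the buffer that
// was derived from the Relax StructInfo so that it carries over the original
// parameter buffer's storage scope and axis separators.
tir::Buffer WithOriginalAttrs(const tir::Buffer& from_sinfo,
                              const tir::Buffer& original) {
  return tir::decl_buffer(from_sinfo->shape, from_sinfo->dtype, from_sinfo->name,
                          /*storage_scope=*/original.scope(),
                          /*axis_separators=*/original->axis_separators);
}
```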


Lunderberg (Contributor) left a comment:


I like the fix overall, thank you! It looks like there are some implicit assumptions about the Relax/TIR being consumed, which may not hold in general.

Review thread on src/relax/transform/fuse_tir.cc:
/*! \brief The IRModule */
const IRModule& mod_;
// size_t call_num_inputs_ = -1;
Map<Var, tir::Buffer> relax_to_tir_var_map_;
Lunderberg (Contributor):


This data structure assumes that there is a 1:1 mapping from `relax::Var` to `tir::Buffer` across the entire fused function. This would produce incorrect results for cases where the same tensor is used as multiple arguments (e.g. `R.add(A, A)`), or where the same tensor is used as an argument to more than one function (e.g. the tensor `A` corresponds to two different TIR buffers in the sequence `mean = R.mean(A); norm = R.sqrt(mean); A_norm = R.divide(A, norm)`).

quic-sanirudh (Contributor, Author):


Yes, I did think about this issue, but I assumed that even though the same Relax var might map to different buffers, they should all have the same buffer attributes (since their source is the same Relax var). I've also added a validation `ICHECK` to verify that the buffer attributes match (using structural equality).

I've also added a test case to verify this scenario, as suggested in the comment below.
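Roughly, the validation described above could look like this (a sketch under assumed names, not the exact code in fuse_tir.cc):

```cpp
#include <tvm/node/structural_equal.h>

// Sketch of the consistency check: if the Relax var was already recorded
// with a TIR buffer, the attributes of any newly encountered buffer for
// the same var must match.
if (auto prev = relax_to_tir_var_map_.Get(relax_var)) {
  ICHECK(StructuralEqual()(prev.value()->axis_separators, buffer->axis_separators))
      << "Relax var " << relax_var
      << " maps to buffers with mismatched axis_separators";
}
```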

Review thread on src/relax/transform/fuse_tir.cc (outdated):
const auto& tir_var = Downcast<tir::Var>(tir_args[i]);
if (i < num_inputs) {
const auto& relax_var = Downcast<Var>(relax_args[i]);
relax_to_tir_var_map_.Set(relax_var, buffer_map[tir_var]);
Lunderberg (Contributor):


The `buffer_map` does not necessarily contain an entry for `tir_var`. For example, the `relax_var` could have `PrimStructInfo` to pass a primitive scalar to the TIR function. Even if `relax_var` has `TensorStructInfo`, the TIR function may treat the `DLTensor*` as an opaque pointer, passing it to a PackedFunc without having an entry in the `buffer_map`.

The best way to handle these cases is to wrap this line in an `if (auto tir_buffer = buffer_map.Get(tir_var))` conditional, and then use `tir_buffer.value()` inside the conditional instead of `buffer_map[tir_var]`.
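Rendered as code, the suggested guard looks roughly like this (a sketch of the suggestion above, not the exact committed diff):

```cpp
// Only record a mapping when tir_var actually has a backing buffer; scalar
// (PrimStructInfo) arguments and opaque DLTensor* parameters have no
// buffer_map entry.
if (auto tir_buffer = buffer_map.Get(tir_var)) {
  relax_to_tir_var_map_.Set(relax_var, tir_buffer.value());
}
```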

quic-sanirudh (Contributor, Author):


Thanks for the catch. I did not consider this case. Fixed.

Review thread on src/relax/transform/fuse_tir.cc (outdated):
for (size_t i = 0; i < tir_args.size(); ++i) {
const auto& tir_var = Downcast<tir::Var>(tir_args[i]);
if (i < num_inputs) {
const auto& relax_var = Downcast<Var>(relax_args[i]);
Lunderberg (Contributor):


This `Downcast<Var>` is not guaranteed to work. While the normalizer will pull most `relax.Var` instances out to their own variable binding, `R.const` arguments may still appear inline.

quic-sanirudh (Contributor, Author):


Thanks, yes, constants are possible. I've updated the output to be a map from `Expr` to `Buffer` instead of from `Var`.
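A minimal sketch of the change (field name assumed): keying the map by `Expr` lets inline constants participate without a `Downcast<Var>`.

```cpp
// Expr covers both relax::Var arguments and inline relax::Constant
// arguments, so no Downcast<Var> is needed before inserting.
Map<Expr, tir::Buffer> relax_to_tir_buffer_map_;

const Expr& arg = relax_args[i];
if (auto tir_buffer = buffer_map.Get(tir_var)) {
  relax_to_tir_buffer_map_.Set(arg, tir_buffer.value());
}
```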

Review thread on src/relax/transform/fuse_tir.cc (outdated):
cls = Before
with R.dataflow():
w = R.call_tir(
cls.add, [x, y], out_sinfo=R.Tensor([T.int64(16), T.int64(32)], "float32")
Lunderberg (Contributor):


Can we add a test case for incompatible usage of a single Relax var? As currently written, we could have a single Relax variable that is used in two separate `R.call_tir` statements, where the functions being called impose different restrictions on it; for example, if `x` were used in `cls.add1`, which requires `axis_separators=[1]`, and in `cls.add2`, which requires `axis_separators=[]`. We should be able to identify this case and raise an error when it occurs.

(Ideally, that should never happen, but this would be the last point at which we'd have enough information to catch this failure mode at compile-time.)

quic-sanirudh (Contributor, Author):


I've added the test case to check possible inconsistencies.

quic-sanirudh (Contributor, Author):

@Lunderberg I think I've addressed your comments. When you get a chance, could you please take a look?

Lunderberg (Contributor) left a comment:


Thank you for making the changes; looks good!

quic-sanirudh merged commit 5bfca2e into apache:main on Jun 17, 2024. 28 checks passed.
quic-sanirudh (Contributor, Author):

Thanks for taking the time to review and provide feedback.
