
[MetaSchedule] Add JSON Database Validation Scripts #12948

Merged: 12 commits merged into apache:main from feature/2022-09-29/database-validation on Nov 9, 2022

Conversation


@zxybazh zxybazh commented Sep 29, 2022

This PR introduces a validation script that checks result accuracy between each scheduled IRModule and the corresponding original IRModule stored in a MetaSchedule database. The validation function can also be reused for other types of MetaSchedule databases. For each record, the result is printed to the screen as either validation passed or failed.

CC @junrushao
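
To illustrate the idea, here is a minimal sketch of what such a per-record check could look like. The helper names, the tolerances, and the input-generation logic are assumptions for illustration only, not the actual implementation in this PR:

import numpy as np
import tvm


def random_inputs(mod):
    """Generate random arrays matching the buffer map of the "main" PrimFunc."""
    func = mod["main"]
    args = []
    for param in func.params:
        buf = func.buffer_map[param]
        shape = [int(dim) for dim in buf.shape]
        args.append(np.random.uniform(size=shape).astype(buf.dtype))
    return args


def check_record(original_mod, scheduled_mod, target, dev, inputs):
    """Build both IRModules, run them on the same inputs, and compare the results."""
    original_rt = tvm.build(original_mod, target=target)
    scheduled_rt = tvm.build(scheduled_mod, target=target)

    # Copy the shared inputs so each run writes into its own output buffers.
    original_args = [tvm.nd.array(x, device=dev) for x in inputs]
    scheduled_args = [tvm.nd.array(x, device=dev) for x in inputs]

    original_rt(*original_args)
    scheduled_rt(*scheduled_args)

    # Compare every argument buffer; a mismatch raises AssertionError.
    for lhs, rhs in zip(original_args, scheduled_args):
        np.testing.assert_allclose(lhs.numpy(), rhs.numpy(), rtol=1e-4, atol=1e-4)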

When running the checks, the expected output looks like:

Progress      1 /   1003 checked.
Progress      2 /   1003 checked.
Progress      3 /   1003 checked.
Progress      4 /   1003 checked.
Progress      5 /   1003 checked.
...

If everything runs well, the script prints Validation passed! at the end.
If there is an unexpected error or a mismatched result, it prints the IRModules, inputs, and exceptions involved, as in the example below; a rough sketch of how such a loop and its failure reporting might be put together follows the example.

Progress     74 /   1003 checked.
Validation failed!

Original Result:
------------------------------
[array([[[[0.5871589 , 0.47408924, 0.60046124, ..., 0.12766571,
          0.3109562 , 0.663565  ],
         [0.14416751, 0.41065693, 0.86465174, ..., 0.67013186,
          0.93802315, 0.9002784 ],
         ...

Scheduled Result:
------------------------------
[array([[[[0.5871589 , 0.47408924, 0.60046124, ..., 0.12766571,
          0.3109562 , 0.663565  ],
         [0.14416751, 0.41065693, 0.86465174, ..., 0.67013186,
          0.93802315, 0.9002784 ],
         [0.32672247, 0.51320565, 0.07578897, ..., 0.8171424 ,
          0.835486  , 0.26994228],
         ...,

Input:
------------------------------
[array([[[[0.5871589 , 0.47408924, 0.60046124, ..., 0.12766571,
          ...

Original IRModule:
------------------------------
# from tvm.script import tir as T
@tvm.script.ir_module
class Module:
    @T.prim_func
    def main(p0: T.Buffer[(1, 7, 7, 512), "float32"], tensor: T.Buffer[(1, 1, 1, 512), "float32"]) -> None:
        # function attr dict
        T.func_attr({"tir.noalias": True, "global_symbol": "main"})
        # body
        # with T.block("root")
        tensor_1 = T.alloc_buffer([1, 1, 1, 512], dtype="float32")
        for i0, i1, i2, i3, i4, i5 in T.grid(1, 1, 1, 512, 7, 7):
            with T.block("tensor"):
                ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5])
                T.reads(p0[ax0, ax1 * 7 + rv0, ax2 * 7 + rv1, ax3])
                T.writes(tensor_1[ax0, ax1, ax2, ax3])
                with T.init():
                    tensor_1[ax0, ax1, ax2, ax3] = T.float32(0)
                tensor_1[ax0, ax1, ax2, ax3] = tensor_1[ax0, ax1, ax2, ax3] + p0[ax0, ax1 * 7 + rv0, ax2 * 7 + rv1, ax3]
        for i0, i1, i2, i3 in T.grid(1, 1, 1, 512):
            with T.block("tensor_1"):
                ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(tensor_1[ax0, ax1, ax2, ax3])
                T.writes(tensor[ax0, ax1, ax2, ax3])
                tensor[ax0, ax1, ax2, ax3] = tensor_1[ax0, ax1, ax2, ax3] * T.float32(0.020408163265306121)
    


Scheduled IRModule:
------------------------------
# from tvm.script import tir as T
@tvm.script.ir_module
class Module:
    @T.prim_func
    def main(p0: T.Buffer[(1, 7, 7, 512), "float32"], tensor: T.Buffer[(1, 1, 1, 512), "float32"]) -> None:
        # function attr dict
        T.func_attr({"tir.noalias": True, "global_symbol": "main"})
        # body
        # with T.block("root")
        tensor_shared = T.alloc_buffer([1, 1, 1, 512], dtype="float32", scope="shared")
        for i0_i1_i2_i3_0_fused in T.thread_binding(2, thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step":T.int64(64), "pragma_unroll_explicit":T.int64(1)}):
            for ax0, ax1, ax2, ax3, ax4_ax5_fused_0 in T.grid(1, 1, 1, 256, 1):
                for ax4_ax5_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
                    with T.block("tensor"):
                        T.where(ax4_ax5_fused_0 * 256 + ax4_ax5_fused_1 < 49)
                        ax0_1, ax1_1, ax2_1 = T.axis.remap("SSS", [ax0, ax1, ax2])
                        ax3_1 = T.axis.spatial(512, i0_i1_i2_i3_0_fused * 256 + ax3)
                        rv0 = T.axis.reduce(7, (ax4_ax5_fused_0 * 256 + ax4_ax5_fused_1) // 7)
                        rv1 = T.axis.reduce(7, (ax4_ax5_fused_0 * 256 + ax4_ax5_fused_1) % 7)
                        T.reads(p0[ax0_1, ax1_1 * 7 + rv0, ax2_1 * 7 + rv1, ax3_1])
                        T.writes(tensor_shared[ax0_1, ax1_1, ax2_1, ax3_1])
                        with T.init():
                            tensor_shared[ax0_1, ax1_1, ax2_1, ax3_1] = T.float32(0)
                        tensor_shared[ax0_1, ax1_1, ax2_1, ax3_1] = tensor_shared[ax0_1, ax1_1, ax2_1, ax3_1] + p0[ax0_1, ax1_1 * 7 + rv0, ax2_1 * 7 + rv1, ax3_1]
            for i3_1 in T.thread_binding(256, thread="threadIdx.x"):
                with T.block("tensor_1"):
                    ax0 = T.axis.spatial(1, 0)
                    ax1 = T.axis.spatial(1, 0)
                    ax2 = T.axis.spatial(1, 0)
                    ax3 = T.axis.spatial(512, i0_i1_i2_i3_0_fused * 256 + i3_1)
                    T.reads(tensor_shared[ax0, ax1, ax2, ax3])
                    T.writes(tensor[ax0, ax1, ax2, ax3])
                    tensor[ax0, ax1, ax2, ax3] = tensor_shared[ax0, ax1, ax2, ax3] * T.float32(0.020408163265306121)
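
For reference, a rough sketch of the outer loop over the database records, reusing the hypothetical helpers from the sketch above. The JSON file paths and the target are placeholders, not values taken from this PR:

import tvm
from tvm import meta_schedule as ms

# Placeholder target/device; the real script also validates GPU schedules.
target = tvm.target.Target("llvm")
dev = tvm.cpu(0)

database = ms.database.JSONDatabase(
    path_workload="database_workload.json",            # placeholder path
    path_tuning_record="database_tuning_record.json",  # placeholder path
)
records = database.get_all_tuning_records()

failed = False
for i, record in enumerate(records, 1):
    original_mod = record.workload.mod
    # Re-apply the stored trace to recover the scheduled IRModule.
    sch = tvm.tir.Schedule(original_mod)
    record.trace.apply_to_schedule(sch, remove_postproc=False)
    scheduled_mod = sch.mod

    print(f"Progress {i:6d} / {len(records):6d} checked.")
    inputs = random_inputs(original_mod)
    try:
        check_record(original_mod, scheduled_mod, target, dev, inputs)
    except Exception as err:  # mismatch, build failure, or runtime error
        failed = True
        print("Validation failed!")
        print(original_mod.script())
        print(scheduled_mod.script())
        print(err)
        break

if not failed:
    print("Validation passed!")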

@zxybazh zxybazh marked this pull request as ready for review October 3, 2022 23:00

@shingjan shingjan left a comment


LGTM. Just some nits

@zxybazh zxybazh force-pushed the feature/2022-09-29/database-validation branch from 32cf60e to c531e50 on October 11, 2022 23:21

tvm-bot commented Oct 12, 2022

Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

Generated by tvm-bot

@areusch areusch added and then removed the needs-triage label (PRs or issues that need to be investigated by maintainers to find the right assignees) Oct 19, 2022

@junrushao junrushao left a comment


Otherwise LGTM

@zxybazh zxybazh force-pushed the feature/2022-09-29/database-validation branch from de13c2b to b0b28cb on November 9, 2022 19:27
@zxybazh zxybazh merged commit 5dc4186 into apache:main Nov 9, 2022
xinetzone pushed a commit to daobook/tvm that referenced this pull request Nov 10, 2022
* Add validation scripts.

* Fix testing script.

* Fix lint.

* Fix lint.

* Fix inputs.

* Fix lint.

* Fix lint.

* Add timer func.

* Fix ci.

* Address comments.

* Add total time statistics.

* Fix lint.
xinetzone pushed a commit to daobook/tvm that referenced this pull request Nov 25, 2022