[Relax] Implement R.ensure_aligned and update memory planning for R.view
vinx13 committed Jul 9, 2024
1 parent 0fc047c commit a7556a8
Showing 9 changed files with 186 additions and 77 deletions.
2 changes: 1 addition & 1 deletion python/tvm/relax/op/memory/__init__.py
@@ -17,4 +17,4 @@
"""Relax memory primitives."""

from .memory import alloc_storage, alloc_tensor, kill_storage, kill_tensor
from .view import view
from .view import view, ensure_aligned
17 changes: 17 additions & 0 deletions python/tvm/relax/op/memory/view.py
@@ -92,3 +92,20 @@ def _normalize(expr, relax_cls):
relative_byte_offset = _normalize(relative_byte_offset, PrimValue)

return _ffi_api.view(data, shape, dtype, relative_byte_offset) # type: ignore


def ensure_aligned(data: Expr) -> Expr:
"""
Ensure the tensor has elem_offset == 0. A copy will be made if necessary.

Parameters
----------
data : relax.Expr
    The input tensor.

Returns
-------
result : relax.Expr
    The aligned tensor.
"""
return _ffi_api.ensure_aligned(data) # type: ignore
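
A minimal usage sketch (mirroring the execution test added later in this commit; `R.memory.view` and `R.memory.ensure_aligned` are the TVMScript spellings of these ops):

from tvm.script import ir as I, relax as R


@I.ir_module
class Module:
    @R.function
    def main(A: R.Tensor([4096], "float32")):
        # A view of the second half of A; the resulting DLTensor carries a
        # nonzero byte_offset into A's storage.
        B = R.memory.view(A, shape=R.shape([2048]), relative_byte_offset=2048 * 4)
        # Force byte_offset == 0; the runtime re-bases the data pointer,
        # copying only where that is not possible.
        C = R.memory.ensure_aligned(B)
        return C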
20 changes: 20 additions & 0 deletions src/relax/backend/vm/vm_builtin_lower.cc
@@ -47,6 +47,10 @@ class VMBuiltinLowerMutator : public ExprMutator {
return Reshape(call);
} else if (call->op == shape_of_op_) {
return ShapeOf(call);
} else if (call->op == view_op_) {
return View(call);
} else if (call->op == ensure_aligned_op_) {
return EnsureAligned(call);
} else if (call->op == to_vdevice_op_) {
return ToDevice(call);
} else if (call->op == make_closure_op_) {
@@ -124,6 +128,19 @@
}
}

Expr View(const Call& view_node) {
StructInfoDeriveFunc infer_sinfo_env_func;
infer_sinfo_env_func = EnvFunc::Get("tvm.relax.struct_info.infer_view_sinfo");
auto runtime_view_sinfo = FuncStructInfo::OpaqueFunc(infer_sinfo_env_func, true);
ExternFunc runtime_view_func("runtime.TVMArrayCreateView", runtime_view_sinfo);
return Call(runtime_view_func, view_node->args, view_node->attrs, {runtime_view_sinfo});
}

Expr EnsureAligned(const Call& call_node) {
ICHECK(call_node->args.size() == 1);
return Call(builtin_ensure_aligned_, call_node->args, Attrs(), {GetStructInfo(call_node)});
}

Expr ShapeOf(const Call& call_node) {
ICHECK(call_node->args.size() == 1);
ICHECK(call_node->struct_info_.defined());
@@ -188,6 +205,8 @@ class VMBuiltinLowerMutator : public ExprMutator {
const Op& call_tir_dyn_op_ = Op::Get("relax.vm.call_tir_dyn");
const Op& reshape_op_ = Op::Get("relax.reshape");
const Op& shape_of_op_ = Op::Get("relax.shape_of");
const Op& view_op_ = Op::Get("relax.memory.view");
const Op& ensure_aligned_op_ = Op::Get("relax.memory.ensure_aligned");
const Op& to_vdevice_op_ = Op::Get("relax.to_vdevice");
const Op& make_closure_op_ = Op::Get("relax.make_closure");
const Op& invoke_closure_op_ = Op::Get("relax.invoke_closure");
@@ -208,6 +227,7 @@
const ExternFunc builtin_to_device_{"vm.builtin.to_device"};
const ExternFunc builtin_make_closure_{"vm.builtin.make_closure"};
const ExternFunc builtin_invoke_closure_{"vm.builtin.invoke_closure"};
const ExternFunc builtin_ensure_aligned_{"vm.builtin.ensure_aligned"};
};

Expr VMBuiltinLower(const Expr& e) { return VMBuiltinLowerMutator().VisitExpr(e); }
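
A sketch of exercising this lowering from Python (assuming the pass is exposed as relax.transform.VMBuiltinLower, as for the other VM lowering passes):

from tvm import relax
from tvm.script import ir as I, relax as R


@I.ir_module
class Module:
    @R.function
    def main(A: R.Tensor([4096], "float32")):
        B = R.memory.ensure_aligned(A)
        return B


# relax.memory.ensure_aligned is rewritten into a call of the extern func
# "vm.builtin.ensure_aligned"; relax.memory.view would likewise be rewritten
# into a call of "runtime.TVMArrayCreateView".
lowered = relax.transform.VMBuiltinLower()(Module)
print(lowered["main"])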
34 changes: 28 additions & 6 deletions src/relax/op/memory/view.cc
@@ -334,13 +334,12 @@ Expr LegalizeView(const BlockBuilder& bb, const Call& call) {
relative_byte_offset = relax::PrimValue::Int64(0);
}

StructInfoDeriveFunc infer_sinfo_env_func;
infer_sinfo_env_func = EnvFunc::Get("tvm.relax.struct_info.infer_view_sinfo");
auto runtime_view_sinfo = FuncStructInfo::OpaqueFunc(infer_sinfo_env_func, true);

ExternFunc runtime_view_func("runtime.TVMArrayCreateView", runtime_view_sinfo);
if (shape.same_as(call->args[1]) && dtype.same_as(call->args[2]) &&
relative_byte_offset.same_as(call->args[3])) {
return call;
}

return Call(runtime_view_func, {data, shape, dtype, relative_byte_offset});
return Call(call->op, {data, shape, dtype, relative_byte_offset});
}

TVM_REGISTER_OP("relax.memory.view")
@@ -355,5 +354,28 @@ TVM_REGISTER_OP("relax.memory.view")
.set_attr<FLegalize>("FLegalize", LegalizeView)
.set_attr<Bool>("FPurity", Bool(true));

Expr ensure_aligned(const Expr& x) {
static const Op& op = Op::Get("relax.memory.ensure_aligned");
return Call(op, {x});
}

TVM_REGISTER_GLOBAL("relax.op.memory.ensure_aligned").set_body_typed(ensure_aligned);

StructInfo InferStructInfoEnsureAligned(const Call& call, const BlockBuilder& ctx) {
if (call->args.size() != 1) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "Operator " << call->op << " should receive 1 argument, "
<< "but received " << call->args);
}
return GetStructInfo(call->args[0]);
}

TVM_REGISTER_OP("relax.memory.ensure_aligned")
.set_num_inputs(1)
.add_argument("x", "Tensor", "The input tensor.")
.set_attr<Bool>("RequiresArgumentShapes", Bool(false))
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoEnsureAligned)
.set_attr<Bool>("FPurity", Bool(true));

} // namespace relax
} // namespace tvm
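
For illustration, a small sketch of what this registration gives the compiler: InferStructInfoEnsureAligned forwards the argument's struct info unchanged, so the result has the input's shape and dtype (the variable names below are hypothetical; this follows the standard BlockBuilder-based inference-test pattern):

import tvm
from tvm import relax

bb = relax.BlockBuilder()
x = relax.Var("x", relax.TensorStructInfo((16, 16), "float32"))
# normalize() runs FInferStructInfo for the freshly constructed call node.
call = bb.normalize(relax.op.memory.ensure_aligned(x))
tvm.ir.assert_structural_equal(call.struct_info, x.struct_info)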
3 changes: 3 additions & 0 deletions src/relax/op/memory/view.h
@@ -32,6 +32,9 @@ namespace relax {
/*! \brief View a tensor with different properties. */
Expr view(Expr x, Optional<Expr> shape, Optional<Expr> dtype, Optional<Expr> relative_byte_offset);

/*! \brief Ensure the tensor has elem_offset == 0. A copy will be made if necessary. */
Expr ensure_aligned(const Expr& x);

} // namespace relax
} // namespace tvm

13 changes: 9 additions & 4 deletions src/relax/transform/static_plan_block_memory.cc
@@ -286,8 +286,13 @@ class TokenAllocator1D {
std::vector<StorageToken> full_pool_;
};

/*! \brief Check if the input op is "relax.reshape". */
bool IsReshape(const Expr& op) { return op.same_as(Op::Get("relax.reshape")); }
/*! \brief Check if the input op is a memory op that returns the same underlying buffer as its input. */
bool IsInplaceMemoryOp(const Expr& op) {
static const Op& reshape_op = Op::Get("relax.reshape");
static const Op& view_op = Op::Get("relax.memory.view");
static const Op& ensure_aligned_op = Op::Get("relax.memory.ensure_aligned");
return op.same_as(reshape_op) || op.same_as(view_op) || op.same_as(ensure_aligned_op);
}

/*! \brief The base class for the storage allocation visitor. */
class StorageAllocatorBaseVisitor : public ExprVisitor {
@@ -498,7 +503,7 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor {
// Create a storage token for builtin alloc_tensor.
this->CreateToken(call);
return;
} else if (IsReshape(call->op)) {
} else if (IsInplaceMemoryOp(call->op)) {
// Reuse the input's token for in-place memory ops (reshape / view / ensure_aligned).
SetTokens(call, GetTokens(call->args[0]));
return;
@@ -751,7 +756,7 @@ class StorageAllocator : public StorageAllocatorBaseVisitor {
block_tokens.push_back(new_token.get());
}
return;
} else if (IsReshape(call->op)) {
} else if (IsInplaceMemoryOp(call->op)) {
Tokens tokens = GetTokens(call->args[0]);
ICHECK(!tokens.IsNested());
if (tokens.IsLeaf()) {
14 changes: 14 additions & 0 deletions src/runtime/relax_vm/builtin.cc
@@ -545,6 +545,20 @@ TVM_REGISTER_GLOBAL("vm.builtin.tensor_to_shape").set_body_typed([](NDArray data
return ShapeTuple(out_shape);
});

TVM_REGISTER_GLOBAL("vm.builtin.ensure_aligned").set_body_typed([](NDArray data) {
if (data->byte_offset == 0) {
return data;
}
DLManagedTensor* dl_tensor = data.ToDLPack();
dl_tensor->dl_tensor.data =
reinterpret_cast<char*>(dl_tensor->dl_tensor.data) + dl_tensor->dl_tensor.byte_offset;
dl_tensor->dl_tensor.byte_offset = 0;
// Note: on platforms that do not support pointer arithmetic on device pointers,
// the data would instead have to be copied into a fresh buffer.
return NDArray::FromDLPack(dl_tensor);
});

} // namespace relax_vm
} // namespace runtime
} // namespace tvm
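
A quick sketch of poking the new builtin directly through the FFI (CPU case; freshly created arrays already have byte_offset == 0, so this exercises the fast path that returns the input storage un-copied):

import numpy as np
import tvm

ensure_aligned = tvm.get_global_func("vm.builtin.ensure_aligned")
arr = tvm.nd.array(np.arange(16, dtype="float32"))
# byte_offset is already zero, so the same NDArray storage comes back.
out = ensure_aligned(arr)
np.testing.assert_allclose(out.numpy(), arr.numpy())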
105 changes: 39 additions & 66 deletions tests/python/relax/test_op_view.py
@@ -483,18 +483,7 @@ def main(A: R.Tensor([4096], "float32")):
class Expected:
@R.function
def main(A: R.Tensor([4096], "float32")):
B = R.ExternFunc(
"runtime.TVMArrayCreateView",
R.Callable(
derive_func="tvm.relax.struct_info.infer_view_sinfo",
purity=True,
),
)(
A,
R.shape([64, 64]),
R.dtype("float32"),
R.prim_value(0),
)
B = R.memory.view(A, shape=R.shape([64, 64]), dtype="float32", relative_byte_offset=0)
return B

After = tvm.relax.transform.LegalizeOps()(Before)
@@ -515,18 +504,7 @@ def main(A: R.Tensor(dtype="float32")):
class Expected:
@R.function
def main(A: R.Tensor(dtype="float32")):
B = R.ExternFunc(
"runtime.TVMArrayCreateView",
R.Callable(
derive_func="tvm.relax.struct_info.infer_view_sinfo",
purity=True,
),
)(
A,
R.shape([64, 64]),
R.dtype("float32"),
R.prim_value(0),
)
B = R.memory.view(A, shape=R.shape([64, 64]), dtype="float32", relative_byte_offset=0)
return B

After = tvm.relax.transform.LegalizeOps()(Before)
@@ -545,17 +523,8 @@ def main(A: R.Tensor([4096], "float32")):
class Expected:
@R.function
def main(A: R.Tensor([4096], "float32")):
B = R.ExternFunc(
"runtime.TVMArrayCreateView",
R.Callable(
derive_func="tvm.relax.struct_info.infer_view_sinfo",
purity=True,
),
)(
A,
R.shape([4096]),
R.dtype("int32"),
R.prim_value(0),
B = R.memory.view(
A, dtype=R.dtype("int32"), shape=R.shape([4096]), relative_byte_offset=0
)
return B

@@ -575,17 +544,8 @@ def main(A: R.Tensor([4096], "float32")):
class Expected:
@R.function
def main(A: R.Tensor([4096], "float32")):
B = R.ExternFunc(
"runtime.TVMArrayCreateView",
R.Callable(
derive_func="tvm.relax.struct_info.infer_view_sinfo",
purity=True,
),
)(
A,
R.shape([4096]),
R.dtype("float32"),
R.prim_value(0),
B = R.memory.view(
A, relative_byte_offset=R.prim_value(0), shape=R.shape([4096]), dtype="float32"
)
return B

@@ -624,29 +584,17 @@ def main(A: R.Tensor([4096], "uint8")):
class Expected:
@R.function
def main(A: R.Tensor([4096], "uint8")):
B = R.ExternFunc(
"runtime.TVMArrayCreateView",
R.Callable(
derive_func="tvm.relax.struct_info.infer_view_sinfo",
purity=True,
),
)(
B = R.memory.view(
A,
R.shape([512]),
R.dtype("int32"),
R.prim_value(0),
shape=R.shape([512]),
dtype=R.dtype("int32"),
relative_byte_offset=R.prim_value(0),
)
C = R.ExternFunc(
"runtime.TVMArrayCreateView",
R.Callable(
derive_func="tvm.relax.struct_info.infer_view_sinfo",
purity=True,
),
)(
C = R.memory.view(
A,
R.shape([16, 64]),
R.dtype("float16"),
R.prim_value(2048),
shape=R.shape([16, 64]),
dtype=R.dtype("float16"),
relative_byte_offset=R.prim_value(2048),
)
return (B, C)

@@ -772,5 +720,30 @@ def main(A: R.Tensor([4096], "uint8")):
tvm.testing.assert_allclose(tvm_output[1].numpy(), np_expected[1])


@tvm.testing.parametrize_targets("llvm", "cuda")
def test_execute_view_with_new_byte_offset_ensure_aligned(target, dev):
@I.ir_module
class Module:
@R.function
def main(A: R.Tensor([4096], "float32")):
B = R.memory.view(
A,
shape=R.shape([16, 64]),
relative_byte_offset=32 * 64 * 4,
)
C = R.memory.ensure_aligned(B)
return C

built = tvm.relax.build(Module, target=target)
vm = tvm.relax.VirtualMachine(built, device=dev)

np_input = np.random.random([4096]).astype("float32")
tvm_input = tvm.nd.array(np_input, dev)
tvm_output = vm["main"](tvm_input)
np_expected = np_input.reshape(64, 64)[32:48, :]

tvm.testing.assert_allclose(tvm_output.numpy(), np_expected)


if __name__ == "__main__":
tvm.testing.main()
55 changes: 55 additions & 0 deletions tests/python/relax/test_transform_static_plan_block_memory.py
@@ -1449,5 +1449,60 @@ def main(
tvm.ir.assert_structural_equal(mod, Expected)


def test_view():
@I.ir_module
class Before:
@T.prim_func
def tir_exp(var_rxplaceholder: T.handle, var_compute: T.handle):
T.evaluate(0)

@R.function
def main():
cls = Before
x = R.builtin.alloc_tensor(R.shape([16, 16]), dtype="float32", runtime_device_index=0)
x1 = R.memory.view(x, [128], "float32", 0)
x2 = R.memory.ensure_aligned(x1)
y = R.builtin.alloc_tensor(R.shape([128]), dtype="float32", runtime_device_index=0)
cls.tir_exp(x2, y)
z = R.builtin.alloc_tensor(R.shape([128]), dtype="float32", runtime_device_index=0)
cls.tir_exp(y, z)
return z

@I.ir_module
class Expected:
@T.prim_func
def tir_exp(var_rxplaceholder: T.handle, var_compute: T.handle):
T.evaluate(0)

@R.function
def main() -> R.Tensor((128,), dtype="float32"):
cls = Expected
storage: R.Object = R.memory.alloc_storage(
R.shape([1024]), R.prim_value(0), R.str("global"), R.dtype("float32")
)
x: R.Tensor((16, 16), dtype="float32") = R.memory.alloc_tensor(
storage, R.prim_value(0), R.shape([16, 16]), R.dtype("float32")
)
x1: R.Tensor((128,), dtype="float32") = R.memory.view(
x, R.shape([128]), R.dtype("float32"), R.prim_value(0)
)
x2: R.Tensor((128,), dtype="float32") = R.memory.ensure_aligned(x1)
storage1: R.Object = R.memory.alloc_storage(
R.shape([512]), R.prim_value(0), R.str("global"), R.dtype("float32")
)
y: R.Tensor((128,), dtype="float32") = R.memory.alloc_tensor(
storage1, R.prim_value(0), R.shape([128]), R.dtype("float32")
)
cls.tir_exp(x2, y)
z: R.Tensor((128,), dtype="float32") = R.builtin.alloc_tensor(
R.shape([128]), R.dtype("float32"), R.prim_value(0), R.str("global")
)
cls.tir_exp(y, z)
return z

after = relax.transform.StaticPlanBlockMemory()(Before)
tvm.ir.assert_structural_equal(after, Expected)


if __name__ == "__main__":
tvm.testing.main()
