Make k2 ragged tensor more PyTorch-y like. #812
Conversation
A preview of the documentation can be found at [preview link]. Will add usage examples to it once the code is finished.
Cool! Likely this backprop would not be used for anything except float and double and half; we'd still do backprop for Ragged in Python, I assume, since it's probably better to only pass around the gradients for the scores, not the whole Arc.
k2/python/csrc/torch/any_tensor.h (Outdated)

    // AnyTensor is introduced to support backward propagations on
    // RaggedAny since there has to be a tensor involved during backprop
    class AnyTensor {
This class is still in a WIP state.
I think RaggedAny would be clearer, since AnyTensor doesn't show a connection to "ragged"; it looks like it could be a non-ragged tensor.
k2/python/csrc/torch/autograd/sum.h (Outdated)

    namespace k2 {

    template <typename T>
k2/python/csrc/torch/autograd/sum.h (Outdated)

    template <typename T>
    class SumFunction : public torch::autograd::Function<SumFunction<T>> {
      static_assert(std::is_floating_point<T>::value);
I discussed this in person with @csukuangfj ... I feel that we are pushing the templating "too far out" here, that this SumFunction does not need to be templated, and we can push the dispatching (i.e. if(float) {.. } else if(double) {...}) further inside by overloading the implementation of SumPerSublist for the type Any (that template would do the dispatching, likely via some macro).
@csukuangfj was concerned that sometimes these implementation wrappers will need to use K2_EVAL(...) for this or that purpose. My response was that we can probably wrap K2_EVAL(..) in some dispatching macro in most cases, and if that turns out to be difficult it's OK to push the dispatching further out, but I don't want to get into the habit of doing the dispatching too far out from the actual implementation code, because:
(i) it will lead to a lot of unnecessary binary code duplication, i.e. the compiler has to create many copies of the functions that aren't meaningfully different, and
(ii) in case we ever merge more closely with the Torch codebase, it will be better to stick to patterns more similar to what they use, and the Torch codebase pushes the dispatching very far in, to directly where the actual data processing happens.
So we agreed that we'll push the dispatching further in, by overloading SumPerSublist for type Any in this case; but I was open to having dispatching further out on a case-by-case basis in case this pattern proves to be hard to use in specific cases.
k2/python/csrc/torch/autograd/sum.h (Outdated)

    SumPerSublist<T>(any.any_.Specialize<T>(), initial_value, &values);
    return ToTorch(values);
    FOR_REAL_AND_INT32_TYPES(t, T, {
The dispatching decision is now made here. As SumPerSublist calls the template SegmentedReduce, maybe we should replace the template SegmentedReduce with a non-templated version accepting RaggedAny and do the dispatching inside it. It can reduce compilation time and reduce the size of the shared library, I think.
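To make this concrete, here is a minimal sketch (the wrapper name, its exact signature, and the way the dtype is obtained are illustrative, not code from this PR) of a non-templated entry point that does the dispatching right next to the implementation, reusing the Specialize<T>(), ToTorch() and FOR_REAL_AND_INT32_TYPES pattern from the diff above:

    // Hypothetical non-templated wrapper: the autograd Function would call this,
    // and the per-dtype dispatch happens here, close to the implementation,
    // instead of templating SumFunction on T.
    torch::Tensor SumPerSublistAny(RaggedAny &any, float initial_value) {
      Dtype t = any.any_.GetDtype();  // however the dtype is obtained in the real code
      torch::Tensor ans;
      FOR_REAL_AND_INT32_TYPES(t, T, {
        Array1<T> values;
        SumPerSublist<T>(any.any_.Specialize<T>(), initial_value, &values);
        ans = ToTorch(values);
      });
      return ans;
    }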
k2/python/csrc/torch/autograd/sum.h (Outdated)

    const T *grad_output_data = grad_output.data_ptr<T>();
    T *ans_data = ans.data_ptr<T>();

    K2_EVAL(
Calling a macro inside another macro is not that ugly, for this specific case, I think.
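For illustration, a hedged sketch of the kind of backward kernel being discussed, reusing the variable names from the diff above; the Ragged/RaggedAny accessors used here are assumptions about the surrounding code, not quotes from this PR:

    // Backward of a per-sublist sum (sketch): each element receives the gradient
    // of the sublist (row) it belongs to, looked up via the row_ids of the last
    // axis; the K2_EVAL lambda runs once per element.
    ContextPtr context = any.any_.Context();
    int32_t num_axes = any.any_.NumAxes();
    const int32_t *row_ids_data = any.any_.shape.RowIds(num_axes - 1).Data();
    const T *grad_output_data = grad_output.data_ptr<T>();
    T *ans_data = ans.data_ptr<T>();
    int32_t num_elements = any.any_.NumElements();

    K2_EVAL(
        context, num_elements, lambda_set_grad, (int32_t i)->void {
          ans_data[i] = grad_output_data[row_ids_data[i]];
        });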
k2/python/csrc/torch/doc/any.h (Outdated)

    namespace k2 {

    static constexpr const char *kRaggedAnyInitDataDoc = R"doc(
    Create a ragged tensor with two axes.
I am trying to put the documentation in C++ headers. Users can view the help information in the usual way, i.e.,

    >>> import k2.ragged
    >>> help(k2.ragged.Tensor.__init__)

Also, the doc is going to be hosted at https://k2-fsa.github.io/k2/index.html. The reason is that the doc now depends on the C++ source code, so we have to compile _k2 to generate the doc.

A preview is available at https://csukuangfj.github.io/k2/python_api/tensor.html#k2.ragged.Tensor
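For illustration, a hedged sketch of the pattern (the method, the constant name other than the header path, and the binding signature are assumptions, not code from this PR): each raw-string constant defined in doc/any.h is passed as the docstring argument of pybind11's .def(), so help() and the Sphinx-generated pages read from a single source.

    // In k2/python/csrc/torch/doc/any.h (sketch): documentation lives next to
    // the C++ code as raw string literals.
    static constexpr const char *kRaggedAnyNumelDoc = R"doc(
    Return the number of elements stored in this ragged tensor.
    )doc";

    // In the binding code (sketch): the constant becomes the Python docstring,
    // so help(k2.ragged.Tensor.numel) and the generated docs share one source.
    any.def("numel", &RaggedAny::Numel, kRaggedAnyNumelDoc);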
The doc is easier to discover as it is bound to the actual class. There is no need to create a fake class just for documentation purposes.
Great!
k2/python/csrc/torch/doc/any.h (Outdated)

    it throws.

    >>> import torch
    impor>>> import k2.ragged as k2r
delete prefix 'impor'.
k2/python/csrc/torch/doc/any.h (Outdated)

    An example string for a 3-axis ragged tensor is given below::

    [ [[1] [2 3]] [[2] [] [3, 4,]] ]
No commas: [ [[1] [2 3]] [[2] [] [3, 4,]] ] -> [ [[1] [2 3]] [[2] [] [3 4]] ]
k2/python/csrc/torch/doc/any.h (Outdated)

    Caution:
      Currently, only support for dtypes ``torch.int32``, ``torch.float32``, and
      ``torch.float64`` are implemented. We can support other types if needed.
delete "are implemented"
k2/python/csrc/torch/doc/any.h (Outdated)

    >>> a = k2r.Tensor([[1], [], [3, 4, 5, 6]])
    >>> a.numel()
    5
    >>> b = k2r.Tensor('[ [[1] [] []] [[2 3]]]')
It is better to use the same constructor in one example.
Thanks, fixed.
k2/python/csrc/torch/v2/any.cu (Outdated)

    any.def("remove_values_eq", &RaggedAny::RemoveValuesEq, py::arg("target"));
    any.def("argmax_per_sublist", &RaggedAny::ArgMaxPerSublist,
            py::arg("initial_value"));
    any.def("max_per_sublist", &RaggedAny::MaxPerSublist,
To make the interface more torch-like, I think we should change this to max and add one more parameter like axis, as Dan suggested before. The same applies to the other per-sublist operators. Of course, we could keep both interfaces if backward compatibility is needed.
Agreed, I think that would be a good idea.
It's OK to have an arg for axis, but only support -1 (or num_axes-1) for now.
Incidentally, in torch, axis is called dim. We should probably stay consistent within k2 for now though.
I agree to remove the per-sublist part. All operations on the values are on the last axis, I think. Is there a need to add an extra argument axis? How is it supposed to be used by users?
I remember Dan said there is a plan to support these operations on other axes; we could add this argument and set its default value to -1. Or we could add the argument when it is needed.
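For illustration, a hedged sketch of what such a torch-like binding could look like (the lambda, return type, and default value are assumptions, not code from this PR):

    // Hypothetical replacement for max_per_sublist: expose it as `max` with an
    // `axis` argument defaulting to -1; other axes are rejected until supported.
    any.def(
        "max",
        [](RaggedAny &self, float initial_value, int32_t axis) -> torch::Tensor {
          K2_CHECK_EQ(axis, -1) << "Only axis == -1 is supported for now";
          return self.MaxPerSublist(initial_value);
        },
        py::arg("initial_value"), py::arg("axis") = -1);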
@@ -686,72 +699,93 @@ static void PybindReplaceFsa(py::module &m) {
    }

    static void PybindCtcGraph(py::module &m) {
      m.def(
          "ctc_graph",
          [](RaggedAny &symbols, torch::optional<torch::Device> = {},
I think we don't need the device argument here.
The purpose is to make the call to k2.linear_fsa in Python as uniform as possible: you don't need to know the type of labels.

Line 66 in 8988489:

    ragged_arc = _k2.linear_fsa(labels, device)
      context = GetCpuContext();
    else
      context = GetCudaContext(gpu_id);
    torch::optional<torch::Device> device = {}, bool modified = false,
What's the benefit of passing the device argument as both torch::Device and std::string? Why not use py::object to avoid implementing it twice?

We can get the context as in the code below, whether the argument passed in is "cuda:0" or a torch.device.
    ContextPtr GetContext(py::object device_obj) {
      auto device = torch::Device(static_cast<py::str>(device_obj));
      if (device.type() == torch::kCPU) {
        return GetCpuContext();
      } else if (device.type() == torch::kCUDA) {
        return GetCudaContext(device.index());
      } else {
        K2_LOG(FATAL) << "Unsupported device: " << device.str();
        return GetCpuContext();  // unreachable code
      }
    }
Yes, your approach would work, but it produces less beautiful documentation.
The current implementation produces the following doc:
    import _k2
    help(_k2.linear_fsa)

    linear_fsa(...) method of builtins.PyCapsule instance
        linear_fsa(*args, **kwargs)
        Overloaded function.

        1. linear_fsa(labels: k2::RaggedAny, device: Optional[torch::Device] = None) -> k2::Ragged<k2::Arc>
        2. linear_fsa(labels: List[int], device: Optional[torch::Device] = None) -> k2::Ragged<k2::Arc>
        3. linear_fsa(labels: List[int], device: Optional[str] = None) -> k2::Ragged<k2::Arc>
        4. linear_fsa(labels: List[List[int]], device: Optional[torch::Device] = None) -> k2::Ragged<k2::Arc>
        5. linear_fsa(labels: List[List[int]], device: Optional[str] = None) -> k2::Ragged<k2::Arc>
Cool! Thanks for the explanation!
Replace _k2.ragged with k2.ragged and replace at::Tensor with torch.Tensor.
Ready for review and merge. I know it contains lots of changes, though most of them are documentation (several thousand lines of documentation). All existing test cases pass. No test case is removed. We need to add more tests to [that file]. Help is wanted (a single test function added to that file is also appreciated).
Will add more tutorials to [link].
Great work!
Merging. Will release a new version tomorrow. The documentation of this pull request can be found at [link].
This pull request aims to kill k2.RaggedInt and k2.RaggedFloat.

Usage example

Will create a new class in C++ to wrap Ragged<Any>.