
Make k2 ragged tensor more PyTorch-y like. #812

Merged: csukuangfj merged 29 commits from the any branch into k2-fsa:master on Sep 7, 2021
Conversation

csukuangfj
Collaborator

This pull request aims to kill k2.RaggedInt and k2.RaggedFloat, replacing them with a single k2.ragged.Tensor class.


Usage example

#!/usr/bin/env python3

import torch
import k2
import _k2

# TODO: will move _k2.ragged to k2.ragged

a = _k2.ragged.tensor([[1, 2], [3, 4.0]])
assert a.dtype == torch.float32

a = _k2.ragged.tensor([[1, 2], [3, 4]])
assert a.dtype == torch.int32

a = _k2.ragged.tensor([[1, 2], [3, 4]], dtype=torch.float64)
assert a.dtype == torch.float64
assert a.device == torch.device("cpu")

a = a.to(torch.device("cuda", 0))
assert a.device == torch.device("cuda", 0)

b = a.to(torch.int32)
assert b.dtype == torch.int32

a = b.to(torch.device("cpu")).to(torch.int64)
assert a.device == torch.device("cpu")
assert a.dtype == torch.int64

assert isinstance(a, k2.ragged.Tensor)

Will create a new class in C++ to wrap Ragged<Any>.
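
For concreteness, here is a minimal sketch of what such a wrapper could look like. It is illustrative only: the header path, member name, and method set are assumptions, not the final k2 API.

```c++
#include <utility>

#include "k2/csrc/ragged.h"  // assumed header providing Ragged<Any>
#include "torch/torch.h"

namespace k2 {

// A single C++ class holds a type-erased Ragged<Any>, so Python sees one
// k2.ragged.Tensor class whose dtype is a runtime property, instead of a
// separate RaggedInt/RaggedFloat class per element type.
struct RaggedAny {
  Ragged<Any> any_;  // the underlying ragged data; dtype stored at runtime

  explicit RaggedAny(Ragged<Any> src) : any_(std::move(src)) {}

  // Conversions mirroring torch.Tensor.to(); implementations omitted here.
  RaggedAny To(torch::ScalarType dtype) const;
  RaggedAny To(torch::Device device) const;
};

}  // namespace k2
```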

@csukuangfj
Collaborator Author

A preview of the documentation can be found at
https://k3.readthedocs.io/en/latest/python_api/tensor.html

Will add usage examples to it once the code is finished.

@danpovey
Collaborator

Cool!
Yes I think this is a good plan.
For others' clarity: the C++ class to wrap Ragged<Any> is so that we can support Torch-compatible backprop in C++. (We also have the option to do this in Python, but C++ will probably be more efficient and more future-proof when we want to do production stuff.)

Likely this backprop would not be used for anything except float, double and half; we'd still do backprop for Ragged<Arc> in Python, I assume, since it's probably better to pass around the gradients for the scores only, not the whole Arc.
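
For reference, here is a minimal, self-contained sketch of what "Torch-compatible backprop in C++" means, using torch::autograd::Function; the SumFunction in this PR follows the same pattern, but this toy version is not the k2 code.

```c++
#include "torch/torch.h"

using torch::autograd::AutogradContext;
using torch::autograd::tensor_list;

// Toy custom autograd function: sums a tensor and propagates the gradient
// back to every element. A k2 version would call SumPerSublist() in
// forward() and scatter grad_output according to the ragged shape in
// backward().
struct ToySumFunction : public torch::autograd::Function<ToySumFunction> {
  static torch::Tensor forward(AutogradContext *ctx, torch::Tensor x) {
    ctx->save_for_backward({x});
    return x.sum();
  }

  static tensor_list backward(AutogradContext *ctx, tensor_list grad_outputs) {
    torch::Tensor x = ctx->get_saved_variables()[0];
    // d(sum)/dx_i == 1, so the incoming gradient is broadcast to x's shape.
    return {grad_outputs[0].expand_as(x)};
  }
};

// Usage:
//   torch::Tensor x = torch::rand({5}, torch::requires_grad());
//   torch::Tensor y = ToySumFunction::apply(x);
//   y.backward();  // x.grad() is now a tensor of ones
```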


// AnyTensor is introduced to support backward propagations on
// RaggedAny since there has to be a tensor involved during backprop
class AnyTensor {
Collaborator Author

This class is still WIP.

Collaborator

I think RaggedAny would be clearer, since AnyTensor doesn't show a connection to "ragged"; it looks like it could be a non-ragged tensor.


namespace k2 {

template <typename T>
Collaborator Author

@danpovey

This shows how to do autograd in C++. The current implementation does not yet give the correct gradient, but it produces a gradient, at least.

Still WIP.


The following screenshot shows some tests for the current commit:

[Screenshot: test output, Aug 26, 2021]


template <typename T>
class SumFunction : public torch::autograd::Function<SumFunction<T>> {
static_assert(std::is_floating_point<T>::value);
Collaborator

I discussed this in person with @csukuangfj ... I feel that we are pushing the templating "too far out" here: this SumFunction does not need to be templated, and we can push the dispatching (i.e. if(float) {...} else if(double) {...}) further inside by overloading the implementation of SumPerSublist for the type Any (that overload would do the dispatching, likely via some macro).

@csukuangfj was concerned that sometimes these implementation wrappers will need to use K2_EVAL(...) for this or that purpose. My response was that we can probably wrap K2_EVAL(..) in some dispatching macro in most cases, and if that turns out to be difficult it's OK to push the dispatching further out, but I don't want to get into the habit of doing the dispatching too far out from the actual implementation code, because:
(i) it will lead to a lot of unnecessary binary code duplication, i.e. the compiler has to create many copies of the functions that aren't meaningfully different, and
(ii) in case we ever merge more closely with the Torch codebase, it will be better to stick to patterns more similar to what they use, and the Torch codebase pushes the dispatching very far in, to directly where the actual data processing happens.

So we agreed that we'll push the dispatching further in, by overloading SumPerSublist for type Any in this case; but I was open to having the dispatching further out on a case-by-case basis, in case this pattern proves to be hard to use in specific cases.
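
As a toy, self-contained illustration of "pushing the dispatching in" (this is not k2 code; all names here are made up): the public entry point is not templated, and the switch over dtypes sits right next to the typed implementation, so the surrounding code is compiled only once.

```c++
#include <cstdint>
#include <stdexcept>

enum class Dtype { kInt32, kFloat32, kFloat64 };

// Typed implementation: the only code that really differs per dtype.
template <typename T>
void SumImpl(const void *data, int64_t n, double *out) {
  const T *p = static_cast<const T *>(data);
  double s = 0;
  for (int64_t i = 0; i != n; ++i) s += p[i];
  *out = s;
}

// Non-templated wrapper: dispatching happens here, next to the
// implementation, so callers (e.g. an autograd Function) need no template.
void Sum(Dtype dtype, const void *data, int64_t n, double *out) {
  switch (dtype) {
    case Dtype::kInt32:   return SumImpl<int32_t>(data, n, out);
    case Dtype::kFloat32: return SumImpl<float>(data, n, out);
    case Dtype::kFloat64: return SumImpl<double>(data, n, out);
  }
  throw std::runtime_error("Unsupported dtype");
}
```

The "far out" alternative would instead template Sum() itself and make every caller do the dtype switch, duplicating the surrounding code per dtype in the binary.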


SumPerSublist<T>(any.any_.Specialize<T>(), initial_value, &values);
return ToTorch(values);
FOR_REAL_AND_INT32_TYPES(t, T, {
Collaborator Author
@csukuangfj Aug 27, 2021

The dispatching decision is now made here.

As SumPerSublist calls the template SegmentedReduce,
maybe we should replace the template SegmentedReduce with a non-templated
version accepting RaggedAny and do dispatching inside it.

That could reduce compilation time and the size of the shared library, I think.

const T *grad_output_data = grad_output.data_ptr<T>();
T *ans_data = ans.data_ptr<T>();

K2_EVAL(
Collaborator Author

Calling a macro inside another macro is not that ugly for this specific case, I think.

namespace k2 {

static constexpr const char *kRaggedAnyInitDataDoc = R"doc(
Create a ragged tensor with two axes.
Collaborator Author

I am trying to put the documentation in C++ headers.
Users can view the help information in the usual way, i.e.,

>>> import k2.ragged
>>> help(k2.ragged.Tensor.__init__)

The output is given below:
[Screenshot: help() output for k2.ragged.Tensor.__init__, Aug 28, 2021]


Also, the doc is going to be hosted at
https://k2-fsa.github.io/k2/index.html

The reason is that the doc now depends on the C++ source code, and we have to compile _k2 to generate it.


A preview is available at
https://csukuangfj.github.io/k2/python_api/tensor.html#k2.ragged.Tensor

Collaborator Author

The doc is easier to discover as it is bound to the actual class. There is no need to create a fake class just for documentation purposes.
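
A sketch of the binding pattern being described, assuming the kRaggedAnyInitDataDoc constant shown in the diff above; the function name, constructor signature, and lambda body are hypothetical, not the merged code:

```c++
#include "pybind11/pybind11.h"

namespace py = pybind11;

namespace k2 {

// The doc string lives in a C++ header and is passed straight to pybind11's
// .def(), so help(k2.ragged.Tensor.__init__) shows it on the real class;
// no doc-only "fake" Python class is needed.
static void PybindRaggedTensor(py::module &m) {  // hypothetical name
  py::class_<RaggedAny> any(m, "Tensor");
  any.def(py::init([](py::list data, py::object dtype) {
            return RaggedAny(data, dtype);  // hypothetical constructor
          }),
          py::arg("data"), py::arg("dtype") = py::none(),
          kRaggedAnyInitDataDoc);
}

}  // namespace k2
```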

@danpovey
Collaborator

Great!

it throws.

>>> import torch
impor>>> import k2.ragged as k2r
Collaborator

delete prefix 'impor'.


An example string for a 3-axis ragged tensor is given below::

[ [[1] [2 3]] [[2] [] [3, 4,]] ]
Collaborator

No comma: change [ [[1] [2 3]] [[2] [] [3, 4,]] ] to [ [[1] [2 3]] [[2] [] [3 4]] ].


Caution:
Currently, only support for dtypes ``torch.int32``, ``torch.float32``, and
``torch.float64`` are implemented. We can support other types if needed.
Collaborator

delete "are implemented"

>>> a = k2r.Tensor([[1], [], [3, 4, 5, 6]])
>>> a.numel()
5
>>> b = k2r.Tensor('[ [[1] [] []] [[2 3]]]')
Collaborator

It is better to use the same constructor in one example.

Collaborator Author

Thanks, fixed.

any.def("remove_values_eq", &RaggedAny::RemoveValuesEq, py::arg("target"));
any.def("argmax_per_sublist", &RaggedAny::ArgMaxPerSublist,
py::arg("initial_value"));
any.def("max_per_sublist", &RaggedAny::MaxPerSublist,
Collaborator

To make the interface more torch-like, I think we should change this to max and add one more parameter like axis, as Dan suggested before. The same applies to the other per-sublist operators.

Of course, we could keep both interfaces if backward compatibility is needed.

Collaborator

Agreed, I think that would be a good idea.
It's OK to have an arg for axis, but only support -1 (or num_axes-1) for now.
Incidentally, in torch, axis is called dim. We should probably stay consistent within k2 for now though.

Collaborator Author

I agree to remove the per-sublist part.

All operations on the values are on the last axis, I think. Is there a need to add an extra axis argument?
How is it supposed to be used by users?

Collaborator
@pkufool Sep 1, 2021

I remember Dan said we have a plan to support operations on other axes; we could add this argument and set its default value to -1. Or we could add the argument when it is needed.
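
A sketch of what the discussed change could look like in the bindings, based on the max_per_sublist binding above; the lambda, the float type of initial_value, and the axis check are illustrative only, not the merged code:

```c++
// Rename "max_per_sublist" to "max" and accept an `axis` argument; anything
// other than the last axis is rejected for now (sketch, not the merged code).
any.def(
    "max",
    [](RaggedAny &self, float initial_value, int axis) {
      K2_CHECK_EQ(axis, -1) << "Only the last axis is supported for now";
      return self.MaxPerSublist(initial_value);
    },
    py::arg("initial_value"), py::arg("axis") = -1);
```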

@@ -686,72 +699,93 @@ static void PybindReplaceFsa(py::module &m) {
}

static void PybindCtcGraph(py::module &m) {
m.def(
"ctc_graph",
[](RaggedAny &symbols, torch::optional<torch::Device> = {},
Collaborator

I think we don't need the device argument here.

Collaborator Author

The purpose is to make the call to k2.linear_fsa in Python as uniform as possible: you don't need to know the type of labels.

ragged_arc = _k2.linear_fsa(labels, device)

context = GetCpuContext();
else
context = GetCudaContext(gpu_id);
torch::optional<torch::Device> device = {}, bool modified = false,
Collaborator

What's the benefit of passing the device argument as both torch::Device and std::string? Why not use py::object to avoid implementing it twice?

We can get the context as in the code below, whether the argument passed in is "cuda:0" or a torch.device.

ContextPtr GetContext(py::object device_obj) {
  auto device = torch::Device(static_cast<py::str>(device_obj));
  if (device.type() == torch::kCPU) {
    return GetCpuContext();
  } else if (device.type() == torch::kCUDA) {
    return GetCudaContext(device.index());
  } else {
    K2_LOG(FATAL) << "Unsupported device: " << device.str();
    return GetCpuContext();   // unreachable code
  }
}

Collaborator Author

Yes, your approach would work, but it produces less beautiful documentation.

The current implementation produces the following doc

import _k2
help(_k2.linear_fsa)
linear_fsa(...) method of builtins.PyCapsule instance
    linear_fsa(*args, **kwargs)
    Overloaded function.

    1. linear_fsa(labels: k2::RaggedAny, device: Optional[torch::Device] = None) -> k2::Ragged<k2::Arc>

    2. linear_fsa(labels: List[int], device: Optional[torch::Device] = None) -> k2::Ragged<k2::Arc>

    3. linear_fsa(labels: List[int], device: Optional[str] = None) -> k2::Ragged<k2::Arc>

    4. linear_fsa(labels: List[List[int]], device: Optional[torch::Device] = None) -> k2::Ragged<k2::Arc>

    5. linear_fsa(labels: List[List[int]], device: Optional[str] = None) -> k2::Ragged<k2::Arc>

Collaborator

Cool! Thanks for the explanation!

@pkufool mentioned this pull request on Sep 7, 2021
@csukuangfj changed the title from "WIP: Make k2 ragged tensor more PyTorch-y like." to "Make k2 ragged tensor more PyTorch-y like." on Sep 7, 2021
@csukuangfj
Collaborator Author

Ready for review and merge.

I know it contains lots of changes, though most of them are documentation (several thousand lines of documentation).

All existing test cases pass. No test case was removed.


We need to add more tests to k2/python/tests/ragged_tensor_test.py.

Help is wanted (even a single test function added to that file is appreciated).

@csukuangfj
Collaborator Author

Will add more tutorials to docs/source/python_tutorials/ragged when I have time.

@danpovey
Collaborator

danpovey commented Sep 7, 2021

Great work!
I looked it over briefly and it looks great.
I think it's OK to merge.

@csukuangfj
Collaborator Author

Merging.

Will release a new version tomorrow.

The documentation of this pull request can be found at
https://k2-fsa.github.io/k2/python_api/api.html#k2-ragged

It contains usage examples for almost every API.
[Screenshot: documentation preview, Sep 7, 2021]

@csukuangfj merged commit fbb10a0 into k2-fsa:master on Sep 7, 2021
@csukuangfj deleted the any branch on September 7, 2021, 12:32