
PyTorch module emits UserWarning #66

Open
pzelasko opened this issue Apr 10, 2022 · 8 comments

Comments


pzelasko commented Apr 10, 2022

When running nara_wpe.torch_wpe.wpe_v6 with PyTorch 1.11, I'm seeing the following warning:

torch.linalg.solve has its arguments reversed and does not return the LU factorization.
To get the LU factorization see torch.lu, which can be used with torch.lu_solve or torch.lu_unpack.
X = torch.solve(B, A).solution
should be replaced with
X = torch.linalg.solve(A, B) (Triggered internally at  ../aten/src/ATen/native/BatchLinearAlgebra.cpp:766.)
  G, _ = torch.solve(P, R)

It seems modifying G, _ = torch.solve(P, R) to G = torch.linalg.solve(R, P) does the trick.
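For context, the new `torch.linalg.solve(A, B)` follows NumPy's argument convention: the coefficient matrix comes first, the right-hand side second, which is exactly the swap the warning describes. A minimal numpy sketch of that convention (the matrices here are made up for illustration):

```python
import numpy as np

# A is the coefficient matrix, B the right-hand side.
# The deprecated torch.solve(B, A) took the RHS first, while
# torch.linalg.solve(A, B) -- like numpy.linalg.solve -- takes A first.
A = np.array([[3.0, 1.0],
              [1.0, 2.0]])
B = np.array([[9.0],
              [8.0]])

X = np.linalg.solve(A, B)  # solves A @ X = B

# The solution satisfies the original system.
assert np.allclose(A @ X, B)
```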

@pzelasko (Author)

It'd also be great if you could release v0.0.8 on PyPI, since it contains the torch_wpe submodule; I'm adding support for WPE as data augmentation in Lhotse, which leverages your library here: lhotse-speech/lhotse#663

@boeddeker (Member)

I published that torch code because the PyTorch people wanted an example that uses complex numbers.
In #46 I accidentally merged the code and then forgot about it. I checked the code, and it does indeed contain an error: the hermitian operation does not conjugate. I will fix it.
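For readers unfamiliar with the bug being described: the hermitian (conjugate) transpose of a complex matrix must conjugate the entries in addition to transposing them. A small numpy sketch of the difference (illustrative only, not the actual nara_wpe code):

```python
import numpy as np

A = np.array([[1 + 2j, 3 - 1j],
              [0 + 1j, 2 + 0j]])

plain_transpose = A.T    # swaps axes only -- the kind of bug reported above
hermitian = A.conj().T   # true conjugate transpose (A^H)

# For complex input the two differ, and the distinction matters:
# e.g. A @ A^H is hermitian (equal to its own conjugate transpose),
# while A @ A.T generally is not.
assert not np.allclose(plain_transpose, hermitian)
G = A @ hermitian
assert np.allclose(G, G.conj().T)
```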

In general, I would recommend using the numpy code instead of the torch code; it

  • is much more tested and used,
  • has different implementations that are faster in different situations or use less memory,
  • has a stabilization implemented (I am not sure, but I got the impression that numpy has better implementations for solve, etc.),
  • and is not constrained to be differentiable.

Nevertheless, I should fix the torch code in nara_wpe. With the newer torch versions, all required operations are implemented.


pzelasko commented Apr 11, 2022

Thank you!

BTW, you might find this interesting: I ran a simple benchmark on a single utterance with Jupyter's %%timeit and saw that the numpy version took 300 ms on average, while the torch implementation took 130 ms on average (on CPU). The missing conjugate likely explains part of the gap, but it still seems worth looking into.
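A plain-script analogue of such a %%timeit comparison can be written with the standard `timeit` module. The workload below is a stand-in linear solve, not the actual nara_wpe call, so only the benchmarking mechanics are shown:

```python
import timeit
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((64, 64))
B = rng.standard_normal((64, 8))

# timeit.repeat returns total seconds for each run of `number` calls;
# reporting the best run, as %%timeit does, reduces scheduler noise.
times = timeit.repeat(lambda: np.linalg.solve(A, B), number=100, repeat=3)
best_ms = min(times) / 100 * 1e3
print(f"{best_ms:.3f} ms per solve")
```

Swapping the lambda for the numpy and torch WPE calls would reproduce the comparison above.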

@boeddeker (Member)

The conjugate should not have this effect; I think it comes from the solve operation or the memory view. PyTorch does not support negative strides, so I used a view that does not match the theory but produces the correct final result.
Thanks for reporting this. I am not sure when I will find the time to investigate. Is your benchmark a shareable toy example?

@pzelasko (Author)

I'll try to find a moment to post it tomorrow.

@pzelasko (Author)

@boeddeker (Member)

I fixed the torch WPE code, removed several deprecation warnings, and pushed a new version to PyPI.

Thanks for sharing the notebook. I am not yet sure when I will find the time to check it and see if I can speed up the numpy code.

@pzelasko (Author)

Thanks!
