This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
[Feature Request] np.where compatible operator #16101
Hey, this is the MXNet Label Bot.
@mxnet-label-bot add [Bug, Feature request]
Scalar tensors like the one below are not supported:
y = mx.nd.array(4)  # y.shape: ()
reminisce changed the title from "[Bug, Feature Request] mx.nd.where()" to "[Feature Request] np.where compatible operator" on Sep 5, 2019.
#16829 merged, closing the issue.
Description
Bug
mx.nd.where() behaves incorrectly when one of the inputs is a zero-size NDArray.
Here is a reproducible example
The output is wrong, and it appears that the zero-size NDArray is never checked. According to the mx.nd.where() documentation, x and y must have the same shape, so we expect an error reporting the shape mismatch. Broadcasting is not supported in the latest version, yet where() still returns an output.
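As a minimal sketch of the check the reporter expected (not MXNet's actual implementation), a strict same-shape validation on plain shape tuples rejects the mismatch instead of producing output. The helper name check_same_shape is hypothetical, and the shapes used below are illustrative only.

```python
# Hypothetical helper (not part of MXNet): the strict rule implied by the
# mx.nd.where() docs when broadcasting is unsupported. Shapes are plain
# tuples, e.g. (1,) for mx.nd.array([4]) or () for a 0-d scalar tensor.
def check_same_shape(x_shape, y_shape):
    if tuple(x_shape) != tuple(y_shape):
        raise ValueError(
            f"x and y must have the same shape, got {tuple(x_shape)} and {tuple(y_shape)}"
        )


# Illustrative shapes: a length-3 x against a 0-d y should be rejected.
try:
    check_same_shape((3,), ())
except ValueError as err:
    print(err)
```

Raising early like this surfaces the forgotten brackets as an error message rather than as a silently wrong result.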
This is also somewhat dangerous: when users forget to type the brackets in mx.nd.array([4]), it silently returns incorrect answers instead of an error message.
Feature Request
Currently, mx.nd.where() has two limitations:
1. Broadcast
As with np.where(), it would be great if mx.nd.where() supported broadcasting, so that cond, x and y with different but compatible input shapes are broadcast to a common shape.
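For reference, np.where broadcasts cond, x and y together using NumPy's standard broadcasting rules. The sketch below computes that common output shape from plain tuples, assuming NumPy semantics; broadcast_shape is a hypothetical helper, not an MXNet or NumPy function.

```python
# Sketch of NumPy-style shape broadcasting (pure Python, no MXNet/NumPy).
# Shapes are aligned from the trailing dimension; at each position the sizes
# must be equal, or one of them must be 1. Missing leading dims act like 1.
def broadcast_shape(*shapes):
    ndim = max((len(s) for s in shapes), default=0)
    result = []
    for axis in range(-1, -ndim - 1, -1):
        size = 1
        for shape in shapes:
            if len(shape) < -axis:
                continue  # this shape has no dimension here: treat as 1
            dim = shape[axis]
            if dim == 1:
                continue  # size-1 dimensions stretch to match the others
            if size != 1 and dim != size:
                raise ValueError(f"shapes {shapes} cannot be broadcast together")
            size = dim
        result.append(size)
    return tuple(reversed(result))


# cond of shape (3, 1), x of shape (1, 4) and a 0-d y give a (3, 4) output.
print(broadcast_shape((3, 1), (1, 4), ()))
```

Under these rules a scalar y (shape ()) is compatible with any cond and x, which is also what makes the scalar-input request below fall out naturally.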
2. Scalar inputs (cond, x and y)
In some situations, we want to supply a constant value for the True or False branch.
It would be more user-friendly if programmers only needed to type
mx.nd.where(cond, x, 0)
instead of
mx.nd.where(cond, x, mx.nd.array([0]))
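A simplified, 1-D-only sketch of the requested scalar support, assuming NumPy-like semantics: scalars and length-1 lists are stretched to the common length, so passing 0 behaves like passing a full array of zeros. where_1d is a hypothetical name, and plain Python lists stand in for NDArrays.

```python
# Hypothetical pure-Python sketch (not MXNet code): each of cond, x and y may
# be a flat list or a plain scalar. Scalars and length-1 lists broadcast to
# the common length, mirroring np.where(cond, x, 0). Always returns a list.
def where_1d(cond, x, y):
    """Element-wise select: x where cond is truthy, else y (1-D only)."""
    args = [a if isinstance(a, list) else [a] for a in (cond, x, y)]
    # Every length other than 1 must agree, as in NumPy broadcasting.
    lengths = {len(a) for a in args if len(a) != 1}
    if len(lengths) > 1:
        raise ValueError(f"cannot broadcast lengths {sorted(lengths)}")
    n = lengths.pop() if lengths else 1
    cond, x, y = (a * n if len(a) == 1 else a for a in args)
    return [xi if ci else yi for ci, xi, yi in zip(cond, x, y)]


# The scalar 0 replaces the explicit array of zeros in the False branch.
print(where_1d([True, False, True], [1, 2, 3], 0))
```

Promoting scalars to length-1 inputs and then reusing the ordinary broadcasting path keeps a single code path for both feature requests.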
Environment info (Required)
Build info (Required if built from source)
Compiler (gcc/clang/mingw/visual studio): gcc
MXNet commit hash: 03f12f0fe706d35c93a2cf721b6101bcbffeb07d
Build config: plain CMakeList.txt with USE_NCCL=1