This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Feature Request] np.where compatible operator #16101

Closed
chongruo opened this issue Sep 5, 2019 · 4 comments

Comments

@chongruo
Contributor

chongruo commented Sep 5, 2019

Description

Bug

mx.nd.where() behaves incorrectly when one of the inputs is a zero-dimensional NDArray (shape ()).

Here is a reproducible example:

import mxnet as mx

cond = mx.nd.array([0])       # cond.shape: (1,)
x = mx.nd.array([[10,10]])    #    x.shape: (1, 2)
y = mx.nd.array(4)            #    y.shape: ()

print(mx.nd.where(cond, x, y))
# output: [[4.0000e+00 3.0773e-41]]

The output is incorrect, and it seems that the zero-dimensional NDArray is never shape-checked. According to the docs of mx.nd.where(), x and y must have the same shape, so we would expect an error saying so. Broadcasting is not supported in the latest version, yet where() still returns an output.

It is also somewhat dangerous: when users forget the brackets and write mx.nd.array(4) instead of mx.nd.array([4]), the operator silently returns an incorrect result rather than an error message.
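For comparison, the same inputs can be run through NumPy's np.where, which this request treats as the reference behavior. This is a NumPy-only sketch (it does not call MXNet). It shows that np.array(4) is a 0-d array of shape () while np.array([4]) has shape (1,), and that NumPy broadcasts all three inputs to the common shape (1, 2).

```python
import numpy as np

# Same inputs as the MXNet reproducer above, built with NumPy instead.
cond = np.array([0])        # shape (1,); 0 is treated as False
x = np.array([[10, 10]])    # shape (1, 2)
y = np.array(4)             # shape (): a 0-d (scalar) array

# Forgetting the brackets changes the shape: (1,) vs ().
print(np.array([4]).shape, y.shape)   # (1,) ()

# np.where broadcasts cond, x and y to a common shape, here (1, 2).
result = np.where(cond, x, y)
print(result)        # [[4 4]]
print(result.shape)  # (1, 2)
```

Every element takes the value of y here because the single False in cond is broadcast across both columns.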


Feature Request

1. Broadcast

Currently, mx.nd.where() has two limitations:

  • x and y must have the same shape
  • If condition does not have the same shape as x, it must be a 1-D array whose size equals the size of x’s first dimension
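The two documented constraints can be written down as an explicit shape check. The sketch below is a hypothetical pure-Python helper, check_legacy_where_shapes (not part of MXNet), that raises an error for the reproducer's shapes instead of producing output. It assumes shapes are passed as tuples.

```python
def check_legacy_where_shapes(cond_shape, x_shape, y_shape):
    """Hypothetical helper: enforce the documented mx.nd.where() shape rules.

    Raises ValueError instead of silently producing output.
    """
    cond_shape, x_shape, y_shape = tuple(cond_shape), tuple(x_shape), tuple(y_shape)
    # Rule 1: x and y must have exactly the same shape.
    if x_shape != y_shape:
        raise ValueError(
            f"x and y must have the same shape, got {x_shape} and {y_shape}")
    # Rule 2: condition either matches x, or is 1-D with length x.shape[0].
    if cond_shape == x_shape:
        return
    if len(cond_shape) == 1 and len(x_shape) >= 1 and cond_shape[0] == x_shape[0]:
        return
    raise ValueError(
        f"condition shape {cond_shape} must equal x's shape {x_shape} "
        "or be 1-D with the size of x's first dimension")


# The reproducer's shapes are rejected by rule 1.
try:
    check_legacy_where_shapes((1,), (1, 2), ())
except ValueError as err:
    print(err)  # x and y must have the same shape, got (1, 2) and ()

# Shapes that satisfy the documented rules pass silently.
check_legacy_where_shapes((2,), (2, 3), (2, 3))
check_legacy_where_shapes((2, 3), (2, 3), (2, 3))
```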

Similar to np.where(), it would be great if mx.nd.where() supported broadcasting, so that cond, x, and y are broadcast to a common shape even when their input shapes differ.
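NumPy's broadcasting rule aligns shapes from the trailing dimension. Two sizes are compatible when they are equal or one of them is 1, and missing leading dimensions count as 1. A minimal pure-Python sketch of that rule follows. The function name broadcast_shape is hypothetical; NumPy 1.20+ provides the equivalent np.broadcast_shapes.

```python
from itertools import zip_longest


def broadcast_shape(*shapes):
    """Return the common shape of `shapes` under NumPy broadcasting rules."""
    result = []
    # Walk dimensions from the right; absent dimensions are filled with 1.
    for dims in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        sizes = set(dims) - {1}
        if len(sizes) > 1:
            raise ValueError(f"shapes {shapes} cannot be broadcast together")
        result.append(sizes.pop() if sizes else 1)
    return tuple(reversed(result))


# The reproducer's inputs: cond (1,), x (1, 2), y () -> (1, 2)
print(broadcast_shape((1,), (1, 2), ()))     # (1, 2)
print(broadcast_shape((3, 1), (1, 4), ()))   # (3, 4)
```

Under this rule the reproducer's inputs are valid and resolve to (1, 2), matching what np.where does with them.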


2. Scalar inputs (cond, x and y)

In some situations, we want to use a constant value for the True or False branch.

It would be more user-friendly if programmers only needed to type
mx.nd.where(cond, x, 0)
instead of
mx.nd.where(cond, x, mx.nd.array([0]))
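In NumPy, a Python scalar is accepted directly as either branch and is broadcast against the other inputs, which is the convenience requested here. The NumPy-only sketch below (not MXNet code) shows that the scalar form and the explicit full-size array form give the same result.

```python
import numpy as np

cond = np.array([True, False, True])
x = np.array([1.0, 2.0, 3.0])

# Scalar 0 for the False branch; NumPy broadcasts it to x's shape.
short_form = np.where(cond, x, 0)

# Equivalent explicit form: build a zero array with the same shape as x.
long_form = np.where(cond, x, np.zeros_like(x))

print(short_form)                             # [1. 0. 3.]
print(np.array_equal(short_form, long_form))  # True
```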

Environment info (Required)

----------Python Info----------
Version      : 3.6.9
Compiler     : GCC 7.3.0
Build        : ('default', 'Jul 30 2019 19:07:31')
Arch         : ('64bit', '')
------------Pip Info-----------
Version      : 19.2.2
Directory    : /home/ubuntu/anaconda3/envs/new/lib/python3.6/site-packages/pip
----------MXNet Info-----------
Version      : 1.6.0
Directory    : /home/ubuntu/new/my-mxnet/python/mxnet
Commit hash file "/home/ubuntu/new/my-mxnet/python/mxnet/COMMIT_HASH" not found. Not installed from pre-built package or built from source.
Library      : ['/home/ubuntu/new/my-mxnet/python/mxnet/../../build/libmxnet.so']
Build features:
No runtime build feature info available
----------System Info----------
Platform     : Linux-4.4.0-1092-aws-x86_64-with-debian-stretch-sid
system       : Linux
node         : ip-172-31-14-150
release      : 4.4.0-1092-aws
version      : #103-Ubuntu SMP Tue Aug 27 10:21:48 UTC 2019
----------Hardware Info----------
machine      : x86_64
processor    : x86_64
Architecture:          x86_64
CPU op-mode(s):        32-bit, 64-bit
Byte Order:            Little Endian
CPU(s):                96
On-line CPU(s) list:   0-95
Thread(s) per core:    2
Core(s) per socket:    24
Socket(s):             2
NUMA node(s):          2
Vendor ID:             GenuineIntel
CPU family:            6
Model:                 85
Model name:            Intel(R) Xeon(R) Platinum 8175M CPU @ 2.50GHz
Stepping:              4
CPU MHz:               2499.998
BogoMIPS:              4999.99
Hypervisor vendor:     KVM
Virtualization type:   full
L1d cache:             32K
L1i cache:             32K
L2 cache:              1024K
L3 cache:              33792K
NUMA node0 CPU(s):     0-23,48-71
NUMA node1 CPU(s):     24-47,72-95
Flags:                 fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology nonstop_tsc aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm mpx avx512f rdseed adx smap clflushopt clwb avx512cd xsaveopt xsavec xgetbv1 ida arat pku
----------Network Test----------
Setting timeout: 10
Timing for MXNet: https://github.com/apache/incubator-mxnet, DNS: 0.0014 sec, LOAD: 0.4787 sec.
Timing for Gluon Tutorial(en): https://gluon.mxnet.io, DNS: 0.1707 sec, LOAD: 0.2402 sec.
Timing for Gluon Tutorial(cn): https://zh.gluon.ai, DNS: 0.0228 sec, LOAD: 0.3108 sec.
Timing for FashionMNIST: https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/dataset/fashion-mnist/train-labels-idx1-ubyte.gz, DNS: 0.0107 sec, LOAD: 0.1101 sec.
Timing for PYPI: https://pypi.python.org/pypi/pip, DNS: 0.0013 sec, LOAD: 0.3356 sec.
Timing for Conda: https://repo.continuum.io/pkgs/free/, DNS: 0.0135 sec, LOAD: 0.0633 sec.
----------Environment----------

Build info (Required if built from source)

Compiler (gcc/clang/mingw/visual studio): gcc

MXNet commit hash: 03f12f0fe706d35c93a2cf721b6101bcbffeb07d

Build config: plain CMakeLists.txt with USE_NCCL=1

@mxnet-label-bot
Contributor

Hey, this is the MXNet Label Bot.
Thank you for submitting the issue! I will try and suggest some labels so that the appropriate MXNet community members can help resolve it.
Here are my recommended label(s): Feature

@chongruo
Contributor Author

chongruo commented Sep 5, 2019

@mxnet-label-bot add [Bug, Feature request]

@reminisce
Contributor

Scalar tensors like the one below are not supported in the mx.nd module. We need to implement a NumPy-compatible where operator for this purpose. I will add this operator to the list and prioritize it.

y = mx.nd.array(4)            #    y.shape: ()
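For reference, in NumPy semantics (which a NumPy-compatible operator would follow), shape () denotes a 0-d scalar tensor holding exactly one element. A genuinely zero-size array instead has a 0 somewhere in its shape. A short NumPy-only sketch:

```python
import numpy as np

y = np.array(4)                  # 0-d scalar tensor
print(y.shape, y.ndim, y.size)   # () 0 1
print(y.item())                  # 4

empty = np.array([])             # a genuinely zero-size array, for contrast
print(empty.shape, empty.size)   # (0,) 0
```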

@reminisce reminisce removed the Bug label Sep 5, 2019
@reminisce reminisce changed the title [Bug, Feature Request] mx.nd.where() [Feature Request] np.where compatible operator Sep 5, 2019
@zachgk zachgk added the Numpy label Sep 16, 2019
@haojin2
Contributor

haojin2 commented Nov 19, 2019

#16829 merged, closing the issue.

@haojin2 haojin2 closed this as completed Nov 19, 2019

6 participants