This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

mxnet.io.NDArrayIter does not properly function when len(data) * 2 < batch_size and last_batch_handle='pad' #15535

Closed
turtleizzy opened this issue Jul 14, 2019 · 6 comments · Fixed by #16166

Comments

@turtleizzy

mxnet version: 1.4.0
minimal steps to reproduce:

import mxnet as mx

data = mx.nd.arange(4)
dtIter = mx.io.NDArrayIter(data, batch_size=9, last_batch_handle='pad')
for i in dtIter:
    print(i.data)
IndexError                                Traceback (most recent call last)
<ipython-input-76-d82503158b2b> in <module>
----> 1 for i in dtIter:
      2     print (i.data)

/usr/local/lib/python3.6/site-packages/mxnet/io/io.py in __next__(self)
    226 
    227     def __next__(self):
--> 228         return self.next()
    229 
    230     def iter_next(self):

/usr/local/lib/python3.6/site-packages/mxnet/io/io.py in next(self)
    678         if not self.iter_next():
    679             raise StopIteration
--> 680         data = self.getdata()
    681         label = self.getlabel()
    682         # iter should stop when last batch is not complete

/usr/local/lib/python3.6/site-packages/mxnet/io/io.py in getdata(self)
    760     def getdata(self):
    761         """Get data."""
--> 762         return self._batchify(self.data)
    763 
    764     def getlabel(self):

/usr/local/lib/python3.6/site-packages/mxnet/io/io.py in _batchify(self, data_source)
    747             pad = self.batch_size - self.num_data + self.cursor
    748             first_data = self._getdata(data_source, start=self.cursor)
--> 749             second_data = self._getdata(data_source, end=pad)
    750             return self._concat(first_data, second_data)
    751         # normal case

/usr/local/lib/python3.6/site-packages/mxnet/io/io.py in _getdata(self, data_source, start, end)
    703                 list(self.idx[s]).index(i)
    704                 for i in sorted(self.idx[s])
--> 705             ]]) for x in data_source
    706         ]
    707 

/usr/local/lib/python3.6/site-packages/mxnet/io/io.py in <listcomp>(.0)
    703                 list(self.idx[s]).index(i)
    704                 for i in sorted(self.idx[s])
--> 705             ]]) for x in data_source
    706         ]
    707 

/usr/local/lib/python3.6/site-packages/mxnet/ndarray/ndarray.py in __getitem__(self, key)
    504         indexing_dispatch_code = _get_indexing_dispatch_code(key)
    505         if indexing_dispatch_code == _NDARRAY_BASIC_INDEXING:
--> 506             return self._get_nd_basic_indexing(key)
    507         elif indexing_dispatch_code == _NDARRAY_ADVANCED_INDEXING:
    508             return self._get_nd_advanced_indexing(key)

/usr/local/lib/python3.6/site-packages/mxnet/ndarray/ndarray.py in _get_nd_basic_indexing(self, key)
    785                 return op.slice(self, begin=(key.start,), end=(key.stop,), step=(key.step,))
    786             elif key.start is not None or key.stop is not None:
--> 787                 return self._slice(key.start, key.stop)
    788             else:
    789                 return self

/usr/local/lib/python3.6/site-packages/mxnet/ndarray/ndarray.py in _slice(self, start, stop)
    900         """
    901         handle = NDArrayHandle()
--> 902         start, stop, _ = _get_index_range(start, stop, self.shape[0])
    903 
    904         check_call(_LIB.MXNDArraySlice(

/usr/local/lib/python3.6/site-packages/mxnet/ndarray/ndarray.py in _get_index_range(start, stop, length, step)
   2325             raise IndexError('Slicing stop %d exceeds limit of %d' % (stop-length, length))
   2326     elif stop > length:
-> 2327         raise IndexError('Slicing stop %d exceeds limit of %d' % (stop, length))
   2328 
   2329     return start, stop, step

IndexError: Slicing stop 5 exceeds limit of 4
@mxnet-label-bot
Contributor

Hey, this is the MXNet Label Bot.
Thank you for submitting the issue! I will try and suggest some labels so that the appropriate MXNet community members can help resolve it.
Here are my recommended labels: Bug

@braindotai

braindotai commented Jul 14, 2019

I was running your code, and it looks like last_batch_handle has nothing to do with the error. You'll get the same error whenever your batch size > 2 * len(data).

data = mx.nd.arange(5)

dtIter = mx.io.NDArrayIter(data, batch_size=11)  # everything is fine for batch_size up to 10
for i in dtIter:
    print(i.data)

prints:
IndexError: Slicing stop 6 exceeds limit of 5
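
The threshold can be checked with a small sketch in plain Python. It only mirrors the `pad = self.batch_size - self.num_data + self.cursor` expression visible in the traceback above; the function names are hypothetical and not part of MXNet, and the rest of `_batchify` is not modeled.

```python
def pad_slice_stop(num_data, batch_size, cursor=0):
    """Stop index used for the wrap-around slice.

    Mirrors the expression shown in the traceback:
        pad = self.batch_size - self.num_data + self.cursor
    """
    return batch_size - num_data + cursor


def overflows(num_data, batch_size, cursor=0):
    """True when the wrap-around slice data[:pad] would run past the data."""
    return pad_slice_stop(num_data, batch_size, cursor) > num_data


# (num_data, batch_size) pairs taken from the examples in this thread.
for num_data, batch_size in [(4, 9), (5, 10), (5, 11)]:
    stop = pad_slice_stop(num_data, batch_size)
    print(num_data, batch_size, stop, overflows(num_data, batch_size))
```

With the cursor at 0, the slice stop exceeds `num_data` exactly when `batch_size > 2 * num_data`, which matches both reported errors (stop 5 with limit 4, stop 6 with limit 5).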

@turtleizzy
Author


last_batch_handle='pad' does matter: 'pad' is the default value of last_batch_handle, and the other options (roll_over and discard) do not raise an exception.

@braindotai

Oh man!! Sorry, I forgot to check what the default is.

@turtleizzy
Author

PS: workaround is to pad manually.
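
A minimal sketch of what manual padding could look like, assuming the goal is to repeat samples from the start of the data until it fills whole batches. The helper `wrap_pad_indices` is hypothetical (not an MXNet function) and works on plain indices, so it runs without MXNet installed.

```python
import math


def wrap_pad_indices(num_data, batch_size):
    """Indices that cover whole batches, wrapping around to the start.

    Hypothetical helper for illustration; repeats as many times as
    needed, so it also works when batch_size > 2 * num_data.
    """
    if num_data <= 0 or batch_size <= 0:
        raise ValueError("num_data and batch_size must be positive")
    num_batches = math.ceil(num_data / batch_size)
    return [i % num_data for i in range(num_batches * batch_size)]


# Sizes from the original report: 4 samples, batch_size=9.
indices = wrap_pad_indices(4, 9)
num_padded = len(indices) - 4  # samples that are repeats, to be ignored downstream
print(indices, num_padded)
```

The index list could then be used to build the padded array before constructing the iterator, for example with `mx.nd.take(data, mx.nd.array(indices))`. Since the iterator would then see only full batches, it presumably reports no padding of its own, so `num_padded` has to be tracked by the caller.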

@frankfliu
Contributor

@mxnet-label-bot add [python, bug]
