Skip to content

Commit

Permalink
Sample python bilinear initializer at integral points in y-direction (a…
Browse files Browse the repository at this point in the history
…pache#12983)

* Sample python bilinear initializer at integral points in y-direction

* Add unit test for bilinear initializer
  • Loading branch information
vladoovtcharov authored and haohuw committed Jun 23, 2019
1 parent 5ce77f2 commit 88ac320
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 2 deletions.
4 changes: 2 additions & 2 deletions python/mxnet/initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def _init_bilinear(self, _, arr):
c = (2 * f - 1 - f % 2) / (2. * f)
for i in range(np.prod(shape)):
x = i % shape[3]
y = (i / shape[3]) % shape[2]
y = (i // shape[3]) % shape[2]
weight[i] = (1 - abs(x / f - c)) * (1 - abs(y / f - c))
arr[:] = weight.reshape(shape)

Expand Down Expand Up @@ -657,7 +657,7 @@ def _init_weight(self, _, arr):
c = (2 * f - 1 - f % 2) / (2. * f)
for i in range(np.prod(shape)):
x = i % shape[3]
y = (i / shape[3]) % shape[2]
y = (i // shape[3]) % shape[2]
weight[i] = (1 - abs(x / f - c)) * (1 - abs(y / f - c))
arr[:] = weight.reshape(shape)

Expand Down
9 changes: 9 additions & 0 deletions tests/python/unittest/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,17 @@ def check_rsp_const_init(init, val):
check_rsp_const_init(mx.initializer.Zero(), 0.)
check_rsp_const_init(mx.initializer.One(), 1.)

def test_bilinear_init():
bili = mx.init.Bilinear()
bili_weight = mx.ndarray.empty((1,1,4,4))
bili._init_weight(None, bili_weight)
bili_1d = np.array([[1/float(4), 3/float(4), 3/float(4), 1/float(4)]])
bili_2d = bili_1d * np.transpose(bili_1d)
assert (bili_2d == bili_weight.asnumpy()).all()

if __name__ == '__main__':
test_variable_init()
test_default_init()
test_aux_init()
test_rsp_const_init()
test_bilinear_init()

0 comments on commit 88ac320

Please sign in to comment.