Figure: sample reconstructions on MNIST. The first column is the input; the other columns are reconstructed outputs.
Requirements:

- Python 3.6
- NumPy >= 1.14
- Matplotlib >= 3.0.0
`mnist_bin.npy` is a NumPy binary file, downloaded from the MNIST site or a GitHub source, that contains 60,000 images of handwritten digits (0-9), each 28x28 pixels. Load it with NumPy:
```python
import numpy as np
mnist = np.load('mnist_bin.npy')  # 60000x28x28
```
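The README does not say how pixels are encoded in `mnist_bin.npy`. A binary RBM usually expects each visible unit to be 0 or 1, so before training it can be worth checking the value range. A minimal sketch, assuming a threshold of 0.5 after scaling to [0, 1]; the helper `binarize` is hypothetical and not part of `rbm.py`, and a synthetic array stands in for the real file:

```python
import numpy as np

def binarize(images, threshold=0.5):
    """Flatten (N, 28, 28) images to (N, 784) and map pixels to {0, 1}.

    Hypothetical helper: assumes pixels are either already in [0, 1]
    or raw grey levels in [0, 255].
    """
    x = np.asarray(images, dtype=np.float64)
    if x.max() > 1.0:            # raw grey levels: rescale to [0, 1]
        x = x / 255.0
    x = x.reshape(len(x), -1)    # one row (visible vector) per image
    return (x > threshold).astype(np.float64)

# Synthetic stand-in for mnist_bin.npy so the sketch runs without the file.
fake = np.random.RandomState(0).randint(0, 256, size=(5, 28, 28))
v_data = binarize(fake)
print(v_data.shape)  # (5, 784)
```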
To use the `RBM` class from `rbm.py`, specify the number of hidden and visible (observed) units at initialization:
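```python
from rbm import RBM  # the RBM class defined in rbm.py

# 100 hidden units; one visible unit per pixel (28 * 28 = 784)
rbm = RBM(n_hidden=100, m_observe=28 * 28)
```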
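For reference, these two sizes parameterize the textbook binary RBM: a weight matrix $W$ of shape `m_observe` x `n_hidden` (here 784 x 100 = 78,400 weights), visible biases $\mathbf{b}$ and hidden biases $\mathbf{c}$. This is the standard formulation; the exact variant implemented in `rbm.py` may differ.

```latex
\begin{aligned}
E(\mathbf{v}, \mathbf{h}) &= -\mathbf{b}^\top \mathbf{v} - \mathbf{c}^\top \mathbf{h} - \mathbf{v}^\top W \mathbf{h} \\
p(h_j = 1 \mid \mathbf{v}) &= \sigma\Big(c_j + \textstyle\sum_i v_i W_{ij}\Big) \\
p(v_i = 1 \mid \mathbf{h}) &= \sigma\Big(b_i + \textstyle\sum_j W_{ij} h_j\Big) \\
\sigma(x) &= \frac{1}{1 + e^{-x}}
\end{aligned}
```

Because there are no hidden-hidden or visible-visible connections, all hidden units are conditionally independent given $\mathbf{v}$ (and vice versa), which is what makes layer-at-a-time sampling possible.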
Train the RBM with the `train` method, passing it the training data. The example below trains on the first 200 images for 10 epochs:
```python
rbm.train(mnist[:200], epochs=10)
```
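The README does not say which learning rule `rbm.train` uses. A common choice for RBMs is one-step contrastive divergence (CD-1). The following is a minimal NumPy sketch of one CD-1 update, assuming binary visible units, a learning rate `lr`, and parameter names `W`, `b`, `c` (all hypothetical here, not taken from `rbm.py`). For simplicity each "epoch" in the toy loop is a single full-batch step; a real implementation would more likely iterate over mini-batches.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def cd1_step(v0, W, b, c, lr=0.1, rng=np.random):
    """One CD-1 update on a batch of binary visible vectors.

    v0: (batch, m_observe); W: (m_observe, n_hidden);
    b: (m_observe,) visible biases; c: (n_hidden,) hidden biases.
    Updates W, b, c in place and returns the mean squared
    reconstruction error of the batch.
    """
    # Positive phase: hidden probabilities and a binary hidden sample.
    ph0 = sigmoid(v0 @ W + c)
    h0 = (rng.random_sample(ph0.shape) < ph0).astype(np.float64)
    # Negative phase: one Gibbs step down to the visible layer and back up.
    pv1 = sigmoid(h0 @ W.T + b)
    ph1 = sigmoid(pv1 @ W + c)
    n = v0.shape[0]
    # Approximate gradient: data statistics minus reconstruction statistics.
    W += lr * (v0.T @ ph0 - pv1.T @ ph1) / n
    b += lr * (v0 - pv1).mean(axis=0)
    c += lr * (ph0 - ph1).mean(axis=0)
    return float(np.mean((v0 - pv1) ** 2))

# Toy run on random binary data with the README's sizes.
rs = np.random.RandomState(0)
m_observe, n_hidden = 28 * 28, 100
W = 0.01 * rs.randn(m_observe, n_hidden)
b = np.zeros(m_observe)
c = np.zeros(n_hidden)
data = (rs.rand(200, m_observe) > 0.5).astype(np.float64)
for epoch in range(10):
    err = cd1_step(data, W, b, c, lr=0.1, rng=rs)
```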
After training, you can sample from the RBM. The result should be an image of a handwritten digit generated by the model rather than one taken from the original dataset. Initializing the sampler with a real image (here `mnist[0]`) usually produces better results than a randomly initialized input.
```python
v = rbm.sample(num_iter=200, v_init=mnist[0])
```
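The README does not document what `rbm.sample` does internally. For a binary RBM, sampling is usually done by block Gibbs sampling: starting from `v_init`, alternately draw the whole hidden layer given the visible layer and the whole visible layer given the hidden layer, `num_iter` times. Because such chains tend to mix slowly, starting from a real digit likely helps by placing the chain near a region the model already assigns high probability. A minimal sketch under that assumption, with hypothetical parameters `W`, `b`, `c` (random here, so the output is noise rather than a digit). Whether `rbm.sample` returns probabilities or binary values is not stated; this version returns probabilities.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gibbs_sample(W, b, c, v_init, num_iter=200, rng=np.random):
    """Run num_iter steps of block Gibbs sampling from v_init.

    Hypothetical stand-in for rbm.sample. W: (m_observe, n_hidden),
    b: (m_observe,), c: (n_hidden,). Returns the visible probabilities
    p(v = 1 | h) from the last step as a flat (m_observe,) vector.
    """
    v = np.asarray(v_init, dtype=np.float64).reshape(-1)  # (28, 28) -> (784,)
    pv = v
    for _ in range(num_iter):
        # Sample all hidden units at once given the visible layer...
        ph = sigmoid(v @ W + c)
        h = (rng.random_sample(ph.shape) < ph).astype(np.float64)
        # ...then all visible units at once given the hidden layer.
        pv = sigmoid(h @ W.T + b)
        v = (rng.random_sample(pv.shape) < pv).astype(np.float64)
    return pv

rs = np.random.RandomState(0)
W = 0.01 * rs.randn(28 * 28, 100)
b = np.zeros(28 * 28)
c = np.zeros(100)
v = gibbs_sample(W, b, c, v_init=np.zeros((28, 28)), num_iter=200, rng=rs)
print(v.reshape((28, 28)).shape)  # (28, 28)
```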
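Visualize the output with Matplotlib:

```python
import matplotlib.pyplot as plt

# the sample is a flat vector of 784 pixels; reshape it back to 28x28
plt.imshow(v.reshape((28, 28)), cmap="gray")
plt.show()
```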
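To reproduce the layout described at the top of this README (input in the first column, reconstructions in the others), several flat 784-pixel vectors can be tiled into one array and shown with a single `plt.imshow` call. A minimal NumPy sketch; the helper `tile_images` and its `pad` argument are hypothetical:

```python
import numpy as np

def tile_images(rows, shape=(28, 28), pad=1):
    """Tile rows of flat images into one 2-D array for plt.imshow.

    rows[i][0] can hold an input image and rows[i][1:] its
    reconstructions. Each cell is followed by `pad` zero pixels
    on the right and bottom as a separator.
    """
    h, w = shape
    n_cols = max(len(r) for r in rows)
    grid = np.zeros((len(rows) * (h + pad), n_cols * (w + pad)))
    for i, row in enumerate(rows):
        for j, img in enumerate(row):
            y, x = i * (h + pad), j * (w + pad)
            grid[y:y + h, x:x + w] = np.asarray(img, dtype=np.float64).reshape(shape)
    return grid

# One row: a white "input" followed by a black "reconstruction".
grid = tile_images([[np.ones(784), np.zeros(784)]])
print(grid.shape)  # (29, 58)
# plt.imshow(grid, cmap="gray"); plt.show()
```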
The full script:
```python
import numpy as np
import matplotlib.pyplot as plt

from rbm import RBM

mnist = np.load('mnist_bin.npy')  # 60000x28x28
n_imgs, n_rows, n_cols = mnist.shape
img_size = n_rows * n_cols
print(mnist.shape)

# construct rbm model: one visible unit per pixel
rbm = RBM(n_hidden=100, m_observe=img_size)

print("Start RBM training.")
# train rbm model using the first 200 mnist images
rbm.train(mnist[:200], epochs=10)
print("Finish RBM training.")

# sample from rbm model, starting the chain at the first image
v = rbm.sample(num_iter=200, v_init=mnist[0])
plt.imshow(v.reshape((n_rows, n_cols)), cmap="gray")
plt.show()
```
For details about RBM, refer to this report.