
Restricted Boltzmann Machine

[Figure: sample reconstructions. The first column is the input; the other columns are reconstructed outputs.]

Requirements

Python 3.6

numpy >= 1.14

matplotlib >= 3.0.0

How to use

mnist_bin.npy is a NumPy binary (.npy) file obtained from the MNIST dataset or a GitHub source. It contains 60,000 images of handwritten digits (0-9), each of shape 28x28. Load it with NumPy:

import numpy as np
mnist = np.load('mnist_bin.npy')  # 60000x28x28
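The RBM is built with m_observe = 28 * 28 = 784 visible units, so each 28x28 image corresponds to a flat vector of 784 values. The README does not say whether rbm.py flattens its input internally, so the sketch below shows that preprocessing explicitly. It uses random arrays as a stand-in for mnist_bin.npy and assumes pixel values scaled to [0, 1]; the 0.5 threshold that makes the pixels binary is a hypothetical step (it leaves the data unchanged if the file already holds 0/1 values).

```python
import numpy as np

# Stand-in for np.load('mnist_bin.npy'): random 28x28 images with values in [0, 1].
rng = np.random.default_rng(0)
mnist = rng.random((10, 28, 28))

n_imgs, n_rows, n_cols = mnist.shape
# Flatten each 28x28 image into one 784-dimensional visible vector.
X = mnist.reshape(n_imgs, n_rows * n_cols)
# Hypothetical binarization step, assuming pixels are already scaled to [0, 1].
X_bin = (X > 0.5).astype(np.float64)
print(X_bin.shape)  # (10, 784)
```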

To use the RBM class from rbm.py, specify the number of hidden units (n_hidden) and visible units (m_observe) at initialization.

from rbm import RBM
rbm = RBM(n_hidden=100, m_observe=28 * 28)
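For orientation, a Bernoulli RBM with these arguments typically holds a weight matrix of shape (m_observe, n_hidden) plus one bias vector per layer, and scores a joint configuration of visible and hidden units with an energy function. The TinyRBM class below is a hypothetical minimal sketch of that structure under those standard assumptions; it is not the code in rbm.py, whose internals are not shown here.

```python
import numpy as np

class TinyRBM:
    """Hypothetical minimal Bernoulli RBM; argument names mirror the README."""

    def __init__(self, n_hidden, m_observe, seed=0):
        rng = np.random.default_rng(seed)
        # One weight per (visible unit, hidden unit) pair; no intra-layer weights.
        self.W = rng.normal(0.0, 0.01, size=(m_observe, n_hidden))
        self.b = np.zeros(m_observe)  # visible biases
        self.c = np.zeros(n_hidden)   # hidden biases

    def energy(self, v, h):
        # E(v, h) = -b^T v - c^T h - v^T W h (lower energy = more probable)
        return -(self.b @ v) - (self.c @ h) - (v @ self.W @ h)

rbm = TinyRBM(n_hidden=100, m_observe=28 * 28)
print(rbm.W.shape)  # (784, 100)
```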

Train the RBM by passing data to its train method. Here it is trained on the first 200 images for 10 epochs.

rbm.train(mnist[:200], epochs=10)
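The README does not say which learning rule train uses. A common choice for RBMs is one-step contrastive divergence (CD-1), sketched below as a hypothetical train_cd1 function. The learning rate lr, the initialization scale, and the full-batch update schedule are all assumptions, not details taken from rbm.py.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def train_cd1(data, n_hidden, epochs=10, lr=0.1, seed=0):
    """Hypothetical CD-1 trainer; data has shape (n_samples, m_observe), values in {0, 1}."""
    rng = np.random.default_rng(seed)
    n, m = data.shape
    W = rng.normal(0.0, 0.01, size=(m, n_hidden))
    b = np.zeros(m)         # visible biases
    c = np.zeros(n_hidden)  # hidden biases
    for _ in range(epochs):
        # Positive phase: hidden probabilities and samples given the data.
        ph = sigmoid(data @ W + c)
        h = (rng.random(ph.shape) < ph).astype(float)
        # Negative phase: one Gibbs step down to the visible layer and back up.
        pv = sigmoid(h @ W.T + b)
        v_neg = (rng.random(pv.shape) < pv).astype(float)
        ph_neg = sigmoid(v_neg @ W + c)
        # Approximate gradient: data statistics minus one-step reconstruction statistics.
        W += lr * (data.T @ ph - v_neg.T @ ph_neg) / n
        b += lr * (data - v_neg).mean(axis=0)
        c += lr * (ph - ph_neg).mean(axis=0)
    return W, b, c

# Usage on random binary data shaped like 200 flattened MNIST images.
rng = np.random.default_rng(1)
data = (rng.random((200, 784)) < 0.2).astype(float)
W, b, c = train_cd1(data, n_hidden=100, epochs=10)
print(W.shape, b.shape, c.shape)  # (784, 100) (784,) (100,)
```

Each epoch here makes a single full-batch update; practical implementations usually loop over mini-batches instead.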

After training, you can sample from the RBM. The result should be an image of a handwritten digit generated by the model rather than one copied from the original dataset. In practice, starting from a good initial image (here, a training image) usually produces better results than starting from randomly initialized inputs.

v = rbm.sample(num_iter=200, v_init=mnist[0])
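Sampling from an RBM is usually done with block Gibbs sampling: starting from v_init, alternately sample all hidden units given the visible units, then all visible units given the hidden units, for num_iter rounds. The gibbs_sample function below is a hypothetical sketch of that procedure, taking a weight matrix W of shape (784, 100) and bias vectors b and c. It returns the visible probabilities from the final step, a common choice for display because they look smoother than binary samples; whether rbm.sample does the same is an assumption.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gibbs_sample(W, b, c, v_init, num_iter=200, seed=0):
    """Hypothetical block Gibbs sampler: alternate h ~ p(h|v) and v ~ p(v|h)."""
    rng = np.random.default_rng(seed)
    v = np.asarray(v_init, dtype=float).reshape(-1)  # flatten e.g. a 28x28 image
    pv = v
    for _ in range(num_iter):
        ph = sigmoid(v @ W + c)
        h = (rng.random(ph.shape) < ph).astype(float)
        pv = sigmoid(h @ W.T + b)
        v = (rng.random(pv.shape) < pv).astype(float)
    return pv  # visible probabilities from the last step

# Usage with untrained, randomly initialized parameters.
rng = np.random.default_rng(2)
W = rng.normal(0.0, 0.01, size=(784, 100))
b, c = np.zeros(784), np.zeros(100)
v0 = (rng.random((28, 28)) < 0.5).astype(float)
v = gibbs_sample(W, b, c, v_init=v0, num_iter=200)
print(v.shape)  # (784,)
```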

Visualize the output with matplotlib.

import matplotlib.pyplot as plt
plt.imshow(v.reshape((28, 28)), cmap="gray")
plt.show()
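To produce a layout like the reconstruction figure at the top of this README (input in the first column, model outputs in the remaining columns), one option is a single row of matplotlib subplots. The show_row helper below is hypothetical. It selects the non-interactive Agg backend so it runs without a display; drop that line and call plt.show() to open a window, or call fig.savefig(...) to write an image file.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; remove for interactive display
import matplotlib.pyplot as plt
import numpy as np

def show_row(v_input, samples):
    """Hypothetical helper: input image first, model outputs in the following columns."""
    images = [v_input] + list(samples)
    # squeeze=False keeps axes 2-D even when there is only one image.
    fig, axes = plt.subplots(1, len(images), figsize=(1.5 * len(images), 1.5), squeeze=False)
    for ax, img in zip(axes[0], images):
        ax.imshow(np.reshape(img, (28, 28)), cmap="gray")
        ax.axis("off")
    return fig

# Usage with random stand-in images.
rng = np.random.default_rng(3)
v_input = rng.random((28, 28))
samples = [rng.random(784) for _ in range(4)]
fig = show_row(v_input, samples)
```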

[Figure: the generated sample v.]

The full script:

import numpy as np
import matplotlib.pyplot as plt
from rbm import RBM  # RBM class defined in rbm.py

mnist = np.load('mnist_bin.npy')  # 60000x28x28
n_imgs, n_rows, n_cols = mnist.shape
img_size = n_rows * n_cols
print(mnist.shape)

# construct rbm model
rbm = RBM(n_hidden=100, m_observe=img_size)  # img_size = 28 * 28

print("Start RBM training.")
# train rbm model using mnist
rbm.train(mnist[:200], epochs=10)
print("Finish RBM training.")

# sample from rbm model
v = rbm.sample(num_iter=200, v_init=mnist[0])
plt.imshow(v.reshape((28, 28)), cmap="gray")
plt.show()

For details about RBM, refer to this report.
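For reference, the standard formulas for a Bernoulli RBM are summarized below, with visible vector v, hidden vector h, weight matrix W, visible biases b, hidden biases c, partition function Z, learning rate η, and logistic sigmoid σ. This notation is chosen here and may differ from the report's; the last line is the usual contrastive-divergence (CD-1) approximation to the log-likelihood gradient.

```latex
\begin{aligned}
E(\mathbf{v}, \mathbf{h}) &= -\mathbf{b}^\top \mathbf{v} - \mathbf{c}^\top \mathbf{h} - \mathbf{v}^\top W \mathbf{h},
\qquad p(\mathbf{v}, \mathbf{h}) = \frac{e^{-E(\mathbf{v}, \mathbf{h})}}{Z} \\
p(h_j = 1 \mid \mathbf{v}) &= \sigma\Big(c_j + \sum_i v_i W_{ij}\Big) \\
p(v_i = 1 \mid \mathbf{h}) &= \sigma\Big(b_i + \sum_j W_{ij} h_j\Big) \\
\Delta W_{ij} &= \eta \left( \langle v_i h_j \rangle_{\text{data}} - \langle v_i h_j \rangle_{\text{recon}} \right)
\end{aligned}
```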

Contributors

fengziyjun
