Categories: Machine Learning

Restricted Boltzmann Machine with DeepLearnToolbox

In an attempt to learn the fundamentals of deep learning (and to fulfill a course assignment too), I wrote a simple Restricted Boltzmann Machine (RBM) in GNU Octave by extracting the relevant parts of the Deep Belief Network (DBN) example code in DeepLearnToolbox. In case you are wondering, an RBM is a machine learning model that Geoffrey Hinton promotes as a basic building block of deep learning.

This code simply loads the MNIST handwritten digits dataset (provided by DeepLearnToolbox) and uses it to train an RBM with 100 hidden nodes. The magic of the RBM in this code happens inside the rbmtrain() function, where the network is trained using Free Energy and Contrastive Divergence (CD). rbmtrain() prints the reconstruction error, and once training finishes (opts.numepochs is set to 5 in the listing) the weights of the RBM are visualized.

function test_RBM 
t = time; 

# load dataset 
load mnist_uint8; 
train_x = double(train_x) / 255; 

# config 
opts.numepochs = 5; 
opts.batchsize = 100; 
opts.momentum = 0; 
opts.alpha = 1; 
rbm.sizes = [100]; 
rbm.alpha = opts.alpha; 
rbm.momentum = opts.momentum; 

# setup 
n = size(train_x, 2);
rbm.sizes = [n, rbm.sizes]; 

# weight 
rbm.W = zeros(rbm.sizes(2), rbm.sizes(1)); 
rbm.vW = zeros(rbm.sizes(2), rbm.sizes(1)); 

# biases 
rbm.b = zeros(rbm.sizes(1), 1); 
rbm.vb = zeros(rbm.sizes(1), 1); 
rbm.c = zeros(rbm.sizes(2), 1); 
rbm.vc = zeros(rbm.sizes(2), 1); 

# train 
rbm = rbmtrain(rbm, train_x, opts); 

# visualize 
figure; 
visualize(rbm.W'); 
disp(['elapsed time: ' num2str(time - t) 's']);

If the code runs correctly, a figure window with the weight visualization will pop up.

MNIST RBM weight visualization
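Besides looking at the weights, it can be instructive to push a few inputs through the trained RBM and compare them with their reconstructions. The following is only a minimal sketch: rbm_reconstruct is a hypothetical helper written for this post (it is not a DeepLearnToolbox function), it assumes the same field layout as the listing above (rbm.W is hidden x visible, rbm.b is the visible bias, rbm.c is the hidden bias), and it uses a tiny zero-weight RBM as a stand-in so that it runs without the toolbox.

```octave
1;  # mark this file as a script so the function below can be defined inline

# Hypothetical helper for illustration; not part of DeepLearnToolbox.
function [x_rec, err] = rbm_reconstruct(rbm, x)
  sigm = @(z) 1 ./ (1 + exp(-z));
  h = sigm(x * rbm.W' + rbm.c');         # hidden probabilities, one row per example
  x_rec = sigm(h * rbm.W + rbm.b');      # visible probabilities given those hidden values
  err = mean(sum((x - x_rec) .^ 2, 2));  # squared error summed over units, averaged over examples
end

# Toy stand-in for a trained RBM: 6 visible units, 3 hidden units, all zeros
rbm.W = zeros(3, 6);
rbm.b = zeros(6, 1);
rbm.c = zeros(3, 1);

x = [1 0 1 0 1 0; 0 1 0 1 0 1];
[x_rec, err] = rbm_reconstruct(rbm, x);
printf('reconstruction error: %f\n', err);
```

To try it on the real model, a call such as rbm_reconstruct(rbm, train_x(1:10, :)) could be placed inside test_RBM after the rbmtrain() call, since rbm and train_x are local to that function.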

 

DeepLearnToolbox provides easy-to-read code that helps you understand deep learning algorithms better, so I encourage you to drill down into the API (e.g. rbmtrain()). Since Octave/Matlab syntax is quite close to mathematical notation, you can compare the implementation directly with the formulas in a theoretical explanation such as http://deeplearning.net/tutorial/rbm.html.
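As a companion when reading rbmtrain(), here is a rough sketch of a single Contrastive Divergence update with one Gibbs step (CD-1) on one mini-batch. It is my own simplified illustration, not the toolbox's code: the name cd1_step and its argument list are hypothetical, momentum is left out, and the actual rbmtrain() may differ in its details. The variable layout follows the listing above (W is hidden x visible, b is the visible bias, c is the hidden bias).

```octave
1;  # mark this file as a script so the function below can be defined inline

# Hypothetical helper for illustration; not part of DeepLearnToolbox.
function [W, b, c, err] = cd1_step(W, b, c, v1, alpha)
  # v1: mini-batch with one example per row, values in [0, 1]
  sigm = @(z) 1 ./ (1 + exp(-z));
  m = size(v1, 1);

  # Positive phase: sample binary hidden states given the data
  p_h1 = sigm(v1 * W' + c');
  h1 = double(p_h1 > rand(size(p_h1)));

  # Negative phase: sample a reconstruction, then recompute hidden probabilities
  p_v2 = sigm(h1 * W + b');
  v2 = double(p_v2 > rand(size(p_v2)));
  p_h2 = sigm(v2 * W' + c');

  # Update: difference between data-driven and reconstruction-driven statistics
  W = W + alpha * (h1' * v1 - p_h2' * v2) / m;
  b = b + alpha * sum(v1 - v2, 1)' / m;
  c = c + alpha * sum(h1 - p_h2, 1)' / m;

  # Squared reconstruction error, averaged over the batch
  err = sum(sum((v1 - v2) .^ 2)) / m;
end

# Toy usage: 20 random binary examples, 6 visible units, 3 hidden units
rand('state', 1);
v1 = double(rand(20, 6) > 0.5);
W = zeros(3, 6);
b = zeros(6, 1);
c = zeros(3, 1);
[W, b, c, err] = cd1_step(W, b, c, v1, 0.1);
printf('reconstruction error after one step: %f\n', err);
```

The weight update is the part to line up with the tutorial's formulas: the first product is the statistic gathered from the data, the second is the same statistic gathered from the reconstruction.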

Published by
yohanes.gultom@gmail.com
