In an attempt to learn the fundamentals of deep learning (and to fulfill a course assignment too), I tried to write a simple Restricted Boltzmann Machine (RBM) in GNU Octave by extracting the Deep Belief Network (DBN) code example from DeepLearnToolbox. In case you are wondering, an RBM is a machine learning model promoted by Geoffrey Hinton as one of the basic building blocks of deep learning.
This code simply loads the MNIST handwritten digits dataset (provided with DeepLearnToolbox) and uses it to train an RBM with 100 hidden nodes. The magic of the RBM in this code happens within the rbmtrain() function, where the network is trained with Contrastive Divergence (CD), an approximation to gradient descent on the negative log-likelihood, which can be expressed in terms of the free energy. After each training epoch the reconstruction error is printed, and once training finishes the learned weights are visualized.
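For orientation, here is the textbook CD-1 update (as presented in the RBM tutorials linked at the end of this post), not a line-by-line description of rbmtrain(). For a data vector $v$, its hidden activation $h$, and their one-step reconstructions $\tilde{v}$, $\tilde{h}$, the updates are roughly

$$\Delta W = \alpha \,\big(h v^{\top} - \tilde{h}\,\tilde{v}^{\top}\big), \qquad \Delta b = \alpha \,(v - \tilde{v}), \qquad \Delta c = \alpha \,(h - \tilde{h}),$$

where $\alpha$ is the learning rate, $b$ the visible bias, and $c$ the hidden bias, matching the fields of the rbm struct in the code below.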
function test_RBM
  t = time;

  # load dataset (MNIST, shipped with DeepLearnToolbox)
  load mnist_uint8;
  train_x = double(train_x) / 255;

  # config
  opts.numepochs = 5;
  opts.batchsize = 100;
  opts.momentum  = 0;
  opts.alpha     = 1;
  rbm.sizes    = [100];
  rbm.alpha    = opts.alpha;
  rbm.momentum = opts.momentum;

  # setup: prepend the number of visible units (784 pixels)
  n = size(train_x, 2);
  rbm.sizes = [n, rbm.sizes];

  # weight matrix (hidden x visible) and its momentum buffer
  rbm.W  = zeros(rbm.sizes(2), rbm.sizes(1));
  rbm.vW = zeros(rbm.sizes(2), rbm.sizes(1));

  # visible (b) and hidden (c) biases with their momentum buffers
  rbm.b  = zeros(rbm.sizes(1), 1);
  rbm.vb = zeros(rbm.sizes(1), 1);
  rbm.c  = zeros(rbm.sizes(2), 1);
  rbm.vc = zeros(rbm.sizes(2), 1);

  # train
  rbm = rbmtrain(rbm, train_x, opts);

  # visualize the learned weights
  figure; visualize(rbm.W');

  disp(['elapsed time: ' num2str(time - t) 's']);
end
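To run it, something along the following lines should work from the Octave prompt. The directory names are only an assumption about where you cloned DeepLearnToolbox, so adjust them to your own layout:

# illustrative Octave session -- paths are assumptions, adjust as needed
addpath('DeepLearnToolbox/DBN');    # rbmtrain()
addpath('DeepLearnToolbox/util');   # visualize(), sigm(), ...
addpath('DeepLearnToolbox/data');   # mnist_uint8.mat
test_RBM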
If the code runs correctly, you will see the weight visualization pop up.
DeepLearnToolbox provides easy-to-read code that helps you understand deep learning algorithms better, so I do encourage you to drill down into the API (e.g. rbmtrain()). Since Octave/Matlab syntax is quite close to mathematical notation, we can compare the implementation directly to the formulas in theoretical explanations such as http://deeplearning.net/tutorial/rbm.html.
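As a rough illustration of what that comparison looks like, here is a single CD-1 mini-batch update written against the rbm struct set up above. This is a simplified sketch for matching code to the formulas (no momentum, no batch/epoch loop), not the actual rbmtrain() implementation:

# One CD-1 update for a mini-batch `batch` (batchsize x visible units).
# Simplified sketch only: momentum and the epoch/batch loop are omitted.
function rbm = cd1_update(rbm, batch, alpha)
  sigm = @(x) 1 ./ (1 + exp(-x));   # logistic sigmoid
  m = size(batch, 1);

  # positive phase: hidden probabilities given the data
  v0 = batch;
  h0 = sigm(repmat(rbm.c', m, 1) + v0 * rbm.W');

  # negative phase: sample hidden states, reconstruct visibles, re-infer hiddens
  h0s = double(h0 > rand(size(h0)));
  v1  = sigm(repmat(rbm.b', m, 1) + h0s * rbm.W);
  h1  = sigm(repmat(rbm.c', m, 1) + v1 * rbm.W');

  # CD-1 gradient estimate: <h v'>_data - <h v'>_recon, averaged over the batch
  rbm.W = rbm.W + alpha * (h0' * v0 - h1' * v1) / m;
  rbm.b = rbm.b + alpha * sum(v0 - v1, 1)' / m;
  rbm.c = rbm.c + alpha * sum(h0 - h1, 1)' / m;
end

If you open DBN/rbmtrain.m in the toolbox, you should recognize essentially the same positive/negative phase structure, with momentum and the loops over batches and epochs added on top.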