### Main Loop

We now have all the necessary ingredients to start training our network.

Before going over the training loop, however, readers should familiarize
themselves with the function `tile_raster_images` (see *Plotting Samples and Filters*). Since
RBMs are generative models, we are interested in sampling from them and
visualizing these samples. We also want to visualize the filters
(weights) learnt by the RBM, to gain insight into what the RBM is actually
doing. Bear in mind, however, that this does not tell the whole story,
since we neglect the biases and plot the weights only up to a shift and
scale (each filter is rescaled to values between 0 and 1).
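To make that rescaling concrete, here is a minimal NumPy sketch of what a filter-tiling helper does: one tile per hidden unit, each filter min-max rescaled on its own, biases ignored. `tile_filters` and `to_unit_interval` are hypothetical names written for illustration; this is not the tutorial's `tile_raster_images`, only an approximation of the behaviour described above.

```python
import numpy as np

def to_unit_interval(x, eps=1e-8):
    """Shift and scale x so that its values lie in [0, 1)."""
    x = np.asarray(x, dtype=np.float64)
    x = x - x.min()
    return x / (x.max() + eps)

def tile_filters(W, img_shape=(28, 28), tile_shape=(10, 10), spacing=1):
    """Hypothetical helper: lay out the columns of W (one filter per hidden
    unit) as a grid of grey-scale tiles separated by `spacing` blank pixels."""
    h, w = img_shape
    rows, cols = tile_shape
    out = np.zeros((rows * (h + spacing) - spacing,
                    cols * (w + spacing) - spacing), dtype=np.uint8)
    filters = W.T  # shape (n_hidden, n_visible): one row per filter
    for k in range(min(rows * cols, filters.shape[0])):
        r, c = divmod(k, cols)
        # each filter is rescaled independently; the biases play no part
        tile = to_unit_interval(filters[k]).reshape(img_shape)
        top, left = r * (h + spacing), c * (w + spacing)
        out[top:top + h, left:left + w] = (tile * 255).astype(np.uint8)
    return out

rng = np.random.RandomState(0)
W = rng.normal(size=(784, 500))  # n_visible x n_hidden, MNIST-sized
img = tile_filters(W)
print(img.shape)  # (289, 289)
```

Because every tile is stretched to the full grey range, two filters whose raw weights differ by an order of magnitude can look equally contrasted, which is one reason the plots do not tell the whole story.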

With these utility functions in hand, we can start training the RBM, plotting and saving the filters after each training epoch. We train the RBM using PCD (Persistent Contrastive Divergence), as it has been shown to lead to a better generative model ([Tieleman08]).
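PCD differs from ordinary CD only in where the negative-phase Gibbs chain starts: CD-k restarts it at the current minibatch, while PCD keeps a set of persistent "fantasy" particles and continues them from one parameter update to the next. The NumPy sketch below shows a single PCD-k update for a binary RBM with sigmoid units, assuming plain stochastic gradient ascent; `pcd_update` and its arguments are hypothetical names chosen for illustration, not the Theano code used in this tutorial.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def pcd_update(v_data, chain_v, W, hbias, vbias, lr, rng, k=1):
    """One PCD-k update for a binary RBM (hypothetical sketch).

    v_data:  minibatch of training vectors, shape (batch, n_visible)
    chain_v: persistent fantasy particles, shape (n_chains, n_visible)
    Returns the updated parameters and the new chain state.
    """
    # Positive phase: hidden probabilities given the data.
    ph_data = sigmoid(v_data @ W + hbias)
    # Negative phase: continue the persistent chain for k Gibbs steps
    # (it is NOT restarted at the data, which is what CD-k would do).
    v = chain_v
    for _ in range(k):
        ph = sigmoid(v @ W + hbias)
        h = (rng.uniform(size=ph.shape) < ph).astype(v.dtype)
        pv = sigmoid(h @ W.T + vbias)
        v = (rng.uniform(size=pv.shape) < pv).astype(v.dtype)
    ph_model = sigmoid(v @ W + hbias)
    # Gradient estimate: data statistics minus chain (model) statistics.
    dW = v_data.T @ ph_data / len(v_data) - v.T @ ph_model / len(v)
    W = W + lr * dW
    hbias = hbias + lr * (ph_data.mean(0) - ph_model.mean(0))
    vbias = vbias + lr * (v_data.mean(0) - v.mean(0))
    return W, hbias, vbias, v

rng = np.random.RandomState(0)
n_vis, n_hid = 6, 4
W = 0.01 * rng.normal(size=(n_vis, n_hid))
hbias, vbias = np.zeros(n_hid), np.zeros(n_vis)
data = (rng.uniform(size=(20, n_vis)) < 0.5).astype(float)
chain = data[:5].copy()  # the chain is initialised once, then persists
for step in range(10):
    W, hbias, vbias, chain = pcd_update(data[:10], chain, W, hbias, vbias,
                                        0.1, rng)
print(W.shape, chain.shape)  # (6, 4) (5, 6)
```

The only state that survives between calls is `chain`; feeding it back in on every update is what makes the contrastive divergence "persistent".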

```
# train_rbm returns `cost` for monitoring, but its main purpose is the
# side effect of applying `updates`, which modifies the RBM parameters
train_rbm = theano.function(
    [index],
    cost,
    updates=updates,
    givens={
        x: train_set_x[index * batch_size: (index + 1) * batch_size]
    },
    name='train_rbm'
)

plotting_time = 0.
start_time = time.perf_counter()

# go through training epochs
for epoch in range(training_epochs):

    # go through the training set
    mean_cost = []
    for batch_index in range(n_train_batches):
        mean_cost += [train_rbm(batch_index)]

    print('Training epoch %d, cost is %f' % (epoch, numpy.mean(mean_cost)))

    # Plot filters after each training epoch
    plotting_start = time.perf_counter()
    # Construct image from the weight matrix
    image = Image.fromarray(
        tile_raster_images(
            X=rbm.W.get_value(borrow=True).T,
            img_shape=(28, 28),
            tile_shape=(10, 10),
            tile_spacing=(1, 1)
        )
    )
    image.save('filters_at_epoch_%i.png' % epoch)
    plotting_stop = time.perf_counter()
    # exclude the plotting time from the reported training time
    plotting_time += (plotting_stop - plotting_start)

end_time = time.perf_counter()

pretraining_time = (end_time - start_time) - plotting_time

print('Training took %f minutes' % (pretraining_time / 60.))
```

Once the RBM is trained, we can use the `gibbs_vhv` function to implement
the Gibbs chain required for sampling. We initialize the Gibbs chain from
test examples (although we could just as well pick them from the training set)
in order to speed up convergence and avoid problems with random
initialization. We again use Theano's `scan` op, this time to run 1000 Gibbs
steps between successive plots.
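Assuming binary units with sigmoid activations, as in this tutorial, one step of `gibbs_vhv` samples the hidden layer given the visible units and then the visible layer given the sampled hiddens. The NumPy sketch below (`gibbs_vhv_np` is a hypothetical name, not part of the tutorial) returns six values in the same order in which this section's `theano.scan` call unpacks them: pre-sigmoid activation, mean-field probability and binary sample, first for the hidden layer and then for the visible layer.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gibbs_vhv_np(v0, W, hbias, vbias, rng):
    """One Gibbs step v -> h -> v for a binary RBM (NumPy sketch)."""
    pre_h = v0 @ W + hbias                      # pre-sigmoid hidden input
    h_mean = sigmoid(pre_h)                     # p(h = 1 | v0)
    h_sample = (rng.uniform(size=h_mean.shape) < h_mean).astype(v0.dtype)
    pre_v = h_sample @ W.T + vbias              # pre-sigmoid visible input
    v_mean = sigmoid(pre_v)                     # p(v = 1 | h_sample)
    v_sample = (rng.uniform(size=v_mean.shape) < v_mean).astype(v0.dtype)
    return pre_h, h_mean, h_sample, pre_v, v_mean, v_sample

rng = np.random.RandomState(0)
W = 0.1 * rng.normal(size=(6, 4))
v0 = (rng.uniform(size=(3, 6)) < 0.5).astype(float)
out = gibbs_vhv_np(v0, W, np.zeros(4), np.zeros(6), rng)
print([o.shape for o in out])
```

Only the last output, the visible sample, is fed back into the next step; that is why the `outputs_info` list passed to `scan` is `None` for the first five outputs and the chain state for the sixth.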

```
#################################
#     Sampling from the RBM     #
#################################
# find out the number of test samples
number_of_test_samples = test_set_x.get_value(borrow=True).shape[0]

# pick random test examples, with which to initialize the persistent chain
test_idx = rng.randint(number_of_test_samples - n_chains)
persistent_vis_chain = theano.shared(
    numpy.asarray(
        test_set_x.get_value(borrow=True)[test_idx:test_idx + n_chains],
        dtype=theano.config.floatX
    )
)
```

Next we run the 20 persistent chains (`n_chains`) in parallel to get our samples. To do so, we compile a Theano function which performs `plot_every` (1000) Gibbs steps and then updates the state of the persistent chain with the new visible sample. We call this function `n_samples` times, so the samples are plotted every 1000 steps.
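Stripped of Theano, this procedure amounts to two nested loops: the inner one advances all chains by `plot_every` Gibbs steps and the outer one keeps one snapshot per plotted row. The sketch below assumes a binary RBM and uses small random weights and a random binary "test set" purely as stand-ins; `run_persistent_chains` is a hypothetical name. The persistent state is the binary visible sample, while the mean-field activations are only kept for plotting.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def run_persistent_chains(W, hbias, vbias, init_v, n_samples, plot_every, rng):
    """Advance parallel Gibbs chains, keeping the mean-field visible
    activations every `plot_every` steps (hypothetical NumPy sketch).

    Returns an array of shape (n_samples, n_chains, n_visible).
    """
    v = init_v.copy()  # plays the role of persistent_vis_chain
    snapshots = []
    for _ in range(n_samples):
        for _ in range(plot_every):  # intermediate states are discarded
            h_mean = sigmoid(v @ W + hbias)
            h = (rng.uniform(size=h_mean.shape) < h_mean).astype(v.dtype)
            v_mean = sigmoid(h @ W.T + vbias)
            v = (rng.uniform(size=v_mean.shape) < v_mean).astype(v.dtype)
        # keep the mean field for plotting; the binary sample (already in v)
        # is what the chain continues from
        snapshots.append(v_mean)
    return np.stack(snapshots)

rng = np.random.RandomState(123)
test_x = (rng.uniform(size=(50, 16)) < 0.5).astype(float)  # stand-in data
n_chains = 4
# start the chains from a random contiguous block of "test" examples
test_idx = rng.randint(len(test_x) - n_chains)
init_v = test_x[test_idx:test_idx + n_chains]
W = 0.1 * rng.normal(size=(16, 8))
samples = run_persistent_chains(W, np.zeros(8), np.zeros(16), init_v,
                                n_samples=3, plot_every=100, rng=rng)
print(samples.shape)  # (3, 4, 16)
```

All chains share the same weights, so advancing them together is a single matrix product per step; this is what makes running many chains in parallel cheap.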

```
plot_every = 1000
# define one step of Gibbs sampling (mf = mean-field) and use scan to
# repeat it `plot_every` times before returning the sample for plotting
(
    [
        presig_hids,
        hid_mfs,
        hid_samples,
        presig_vis,
        vis_mfs,
        vis_samples
    ],
    updates
) = theano.scan(
    rbm.gibbs_vhv,
    outputs_info=[None, None, None, None, None, persistent_vis_chain],
    n_steps=plot_every
)

# add to updates the shared variable that takes care of our persistent
# chain
updates.update({persistent_vis_chain: vis_samples[-1]})
# construct the function that implements our persistent chain.
# we generate the "mean field" activations for plotting and the actual
# samples for reinitializing the state of our persistent chain
sample_fn = theano.function(
    [],
    [
        vis_mfs[-1],
        vis_samples[-1]
    ],
    updates=updates,
    name='sample_fn'
)

# create a space to store the image for plotting (we need to leave
# room for the tile_spacing as well)
image_data = numpy.zeros(
    (29 * n_samples + 1, 29 * n_chains - 1),
    dtype='uint8'
)
for idx in range(n_samples):
    # generate `plot_every` intermediate samples that we discard,
    # because successive samples in the chain are too correlated
    vis_mf, vis_sample = sample_fn()
    print(' ... plotting sample %d' % idx)
    image_data[29 * idx:29 * idx + 28, :] = tile_raster_images(
        X=vis_mf,
        img_shape=(28, 28),
        tile_shape=(1, n_chains),
        tile_spacing=(1, 1)
    )

# construct image
image = Image.fromarray(image_data)
image.save('samples.png')
```
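The row arithmetic in the plotting loop follows from the tile sizes: with `tile_shape=(1, n_chains)` and one pixel of spacing, each strip returned by `tile_raster_images` should be 28 pixels high and `29 * n_chains - 1` wide, and consecutive strips start 29 rows apart so that one blank row separates them. The NumPy sketch below reproduces that layout with a hypothetical `stack_strips` helper, solid stand-in strips, and assumed values `n_samples = 10`, `n_chains = 20`. It also shows that `n` strips need only `29 * n - 1` rows, so the `29 * n_samples + 1` rows allocated above simply leave two extra blank rows at the bottom.

```python
import numpy as np

def stack_strips(strips, tile_h=28, spacing=1):
    """Hypothetical helper: stack equally wide strips vertically, leaving
    `spacing` blank rows between consecutive strips."""
    width = strips[0].shape[1]
    out = np.zeros((len(strips) * (tile_h + spacing) - spacing, width),
                   dtype=np.uint8)
    for idx, strip in enumerate(strips):
        top = idx * (tile_h + spacing)  # 29 * idx for 28-pixel tiles
        out[top:top + tile_h, :] = strip
    return out

n_samples, n_chains = 10, 20
strip_width = 29 * n_chains - 1  # 20 tiles of 28 px plus 19 separators
strips = [np.full((28, strip_width), 255, dtype=np.uint8)
          for _ in range(n_samples)]
image_data = stack_strips(strips)
print(image_data.shape)  # (289, 579)
```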