Wednesday 7 February 2018

Saving and Loading Neural Networks

A very common question I get is how to save a neural network, and load it again later.


Why Save and Load?

There are two key scenarios where being able to save and load a neural network is useful.

  • During a long training period it is sometimes useful to stop and continue at a later time. This might be because you're using a laptop which can't remain on all the time. It could be because you want to stop the training and test how well the neural network performs. Being able to resume training at a different time is really helpful.
  • It is useful to share your trained neural network with others. Being able to save it, and for someone else to load it, is necessary for this to work.





What Do We Save?

In a neural network, the things that do the learning are the link weights. In our Python code, these are represented by the matrices wih and who. The wih matrix contains the weights for the links between the input and hidden layers, and the who matrix contains the weights for the links between the hidden and output layers.

If we save these matrices to a file, we can load them again later. That way we don't need to restart the training from the beginning.
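To get a feel for what actually ends up in those files, here is a small sketch. It assumes illustrative layer sizes of 784 input, 200 hidden and 10 output nodes (our own choice for the example) and small random starting weights; the point is simply that the two matrices are all the network has learned, and that their sizes follow directly from the node counts.

```python
import numpy

# Assumed layer sizes, chosen only for illustration
input_nodes = 784
hidden_nodes = 200
output_nodes = 10

# Assumed initialisation: small random values centred on zero.
# wih has one row per hidden node and one column per input node.
wih = numpy.random.normal(0.0, pow(input_nodes, -0.5), (hidden_nodes, input_nodes))
# who has one row per output node and one column per hidden node.
who = numpy.random.normal(0.0, pow(hidden_nodes, -0.5), (output_nodes, hidden_nodes))

# These two arrays are everything we need to save
total_weights = wih.size + who.size   # 200*784 + 10*200
total_bytes = wih.nbytes + who.nbytes  # 8 bytes per float64 value

print(wih.shape, who.shape)   # (200, 784) (10, 200)
print(total_weights)          # 158800
print(total_bytes)            # 1270400
```

So even a modest network like this one means saving around 1.3 MB of numbers.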


Saving Numpy Arrays

The matrices wih and who are numpy arrays. Luckily, the numpy library provides convenience functions for saving and loading them.

The function to save a numpy array is numpy.save(filename, array). This will store array in filename. If we wanted to add a save method to our neuralNetwork class, we could simply do it like this:

# save neural network weights 
def save(self):
    numpy.save('saved_wih.npy', self.wih)
    numpy.save('saved_who.npy', self.who)
    pass

This will save the wih matrix as a file saved_wih.npy, and the who matrix as a file saved_who.npy.

If we want to stop the training we can issue n.save() in a notebook cell. We can then close down the notebook or even shut down the computer if we need to.
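As an alternative to two separate files, numpy also has numpy.savez, which bundles several named arrays into a single .npz file. This is a swapped-in technique, not the method used above, and the filename saved_weights.npz is our own illustrative choice. The sketch below uses small stand-in matrices rather than a real network's weights, and checks that the arrays survive the round trip unchanged.

```python
import os
import tempfile

import numpy

# Stand-in weight matrices; in the network these would be self.wih and self.who
wih = numpy.random.rand(3, 4)
who = numpy.random.rand(2, 3)

# Work in a throwaway folder so the example doesn't litter the current directory
folder = tempfile.mkdtemp()
path = os.path.join(folder, 'saved_weights.npz')

# savez stores each array under the keyword name it was given
numpy.savez(path, wih=wih, who=who)

# numpy.load on an .npz file gives a dict-like object keyed by those names
with numpy.load(path) as data:
    loaded_wih = data['wih']
    loaded_who = data['who']

print(numpy.array_equal(wih, loaded_wih), numpy.array_equal(who, loaded_who))  # True True
```

One file is a little easier to copy around or share, and the names stored inside it make it harder to mix up the two matrices.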


Loading Numpy Arrays

To load a numpy array we use array = numpy.load(filename). If we want to add a load method to our neuralNetwork class, we must use the same filenames we used to save the data:

# load neural network weights 
def load(self):
    self.wih = numpy.load('saved_wih.npy')
    self.who = numpy.load('saved_who.npy')
    pass

When we come back to our training, we need to run the notebook up to the point just before training. That means running the Python code that sets up the neural network class, and setting the various parameters like the number of input nodes, the data source filenames, etc.

We can then issue n.load() in a notebook cell to load the previously saved neural network weights back into the neural network object n.
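The save and load methods above can be tried end to end with a cut-down class. This is a minimal sketch, not the full neuralNetwork class: it keeps only the constructor, the weights and the two methods, assumes the same random initialisation style, and simulates two separate sessions by creating a fresh network object before loading.

```python
import os
import tempfile

import numpy


class neuralNetwork:
    # Hypothetical cut-down class: only the weights and save/load are kept
    def __init__(self, inputnodes, hiddennodes, outputnodes):
        self.inodes = inputnodes
        self.hnodes = hiddennodes
        self.onodes = outputnodes
        # Random starting weights, shaped (hidden x input) and (output x hidden)
        self.wih = numpy.random.normal(0.0, pow(self.inodes, -0.5), (self.hnodes, self.inodes))
        self.who = numpy.random.normal(0.0, pow(self.hnodes, -0.5), (self.onodes, self.hnodes))

    # save neural network weights
    def save(self):
        numpy.save('saved_wih.npy', self.wih)
        numpy.save('saved_who.npy', self.who)

    # load neural network weights
    def load(self):
        self.wih = numpy.load('saved_wih.npy')
        self.who = numpy.load('saved_who.npy')


# Run in a temporary folder so the .npy files don't clutter the working directory
os.chdir(tempfile.mkdtemp())

# First session: create (and, in practice, train) a network, then save it
n = neuralNetwork(3, 5, 2)
n.save()
original_wih = n.wih.copy()

# Later session: a new network starts with its own random weights...
n = neuralNetwork(3, 5, 2)
# ...until the saved weights are loaded back in
n.load()
print(numpy.array_equal(n.wih, original_wih))  # True
```

Note that the new network is created with the same node counts as the saved one; the loaded matrices simply replace the random ones.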


Gotchas

We've kept the approach simple here, in line with our approach to learning about and coding simple neural networks. That means there are some things our very simple network saving and loading code doesn't do.

Our simple code only saves and loads the two weight matrices, wih and who. It doesn't do anything else. It doesn't check that the loaded data matches the desired size of neural network. We need to make sure that if we load a saved neural network, we continue to use it with the same parameters. For example, we can't train a network, pause, and continue with different settings for the number of nodes in each layer.
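One way to catch a size mismatch early is to check each loaded matrix against the shape the network expects. This is a minimal sketch; the helper load_checked is hypothetical and not part of the original code. For wih the expected shape would be (hidden nodes, input nodes), and for who it would be (output nodes, hidden nodes).

```python
import os
import tempfile

import numpy


def load_checked(filename, expected_shape):
    # Hypothetical helper: load a weight matrix, rejecting it if the shape is wrong
    weights = numpy.load(filename)
    if weights.shape != expected_shape:
        raise ValueError(
            f"{filename} has shape {weights.shape}, expected {expected_shape}"
        )
    return weights


folder = tempfile.mkdtemp()
path = os.path.join(folder, 'saved_wih.npy')

# Save wih from a network with 3 input nodes and 5 hidden nodes
numpy.save(path, numpy.random.rand(5, 3))

# Loading into a network of the same size works
wih = load_checked(path, (5, 3))

# Loading into a network with 10 hidden nodes is caught straight away
try:
    load_checked(path, (10, 3))
except ValueError as err:
    print(err)
```

Without a check like this, the mismatch would only show up later as a confusing error in the matrix multiplications during training or querying.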

If we want to share our neural network, the people we share it with also need to be running the same Python code. The data we're passing them isn't rich enough to be independent of any particular neural network code. Efforts to develop an open, inter-operable data standard have started, for example the Open Neural Network Exchange (ONNX) format.


HDF5 for Very Large Data

In some cases, with very large networks, the amount of data to be saved and loaded can be quite big. In my own experience from around 2016, the normal saving of numpy arrays in this way didn't always work. I then fell back to a slightly more involved method to save and load data using the very mature HDF5 data format, popular in science and engineering.

The Anaconda Python distribution allows you to install the h5py package, which gives Python the ability to work with HDF5 data.

HDF5 data stores do more than simple data saving and loading. They have the idea of a group, like a folder, which can contain several data sets, such as numpy arrays. The data stores also keep track of data set names, rather than just blindly saving data. For very large data sets, the data can be traversed and sliced on disk without having to load it all into memory before subsets are taken.
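Here is a minimal sketch of those ideas using h5py, assuming the h5py package is installed (for example through Anaconda). The group name weights and the filename saved_weights.h5 are our own illustrative choices, and the matrices are random stand-ins for a network's wih and who. It puts both matrices in one group, then reads back just two rows of wih, which h5py fetches from disk without loading the whole data set.

```python
import os
import tempfile

import h5py
import numpy

# Stand-in weight matrices for illustration
wih = numpy.random.rand(200, 784)
who = numpy.random.rand(10, 200)

path = os.path.join(tempfile.mkdtemp(), 'saved_weights.h5')

# Write: a group acts like a folder holding named data sets
with h5py.File(path, 'w') as f:
    weights = f.create_group('weights')
    weights.create_dataset('wih', data=wih)
    weights.create_dataset('who', data=who)

# Read: data sets are found by name, like paths in a folder tree
with h5py.File(path, 'r') as f:
    print(sorted(f['weights'].keys()))
    # Slicing reads only the requested rows from disk
    first_rows = f['weights/wih'][0:2, :]
    # [()] reads a whole data set into a numpy array
    loaded_who = f['weights/who'][()]

print(first_rows.shape)  # (2, 784)
```

For matrices this small the slicing makes little difference, but the same pattern works when a data set is far larger than the available memory.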

You can explore more here: http://docs.h5py.org/en/latest/quick.html#quick