emstoudenmire / TNML


NOTE: these codes are research, proof-of-principle codes only. They are not intended to demonstrate the state of the art in training times for matrix product states in machine learning.

If you are seeking fast approaches for optimizing MPS, we recommend trying newer libraries which use stochastic gradient optimization methods, such as TorchMPS: https://github.com/jemisjoky/TorchMPS

Tensor network machine learning

Codes based on the paper "Supervised Learning with Quantum-Inspired Tensor Networks" by Miles Stoudenmire and David Schwab. http://arxiv.org/abs/1605.05775

Also see "Tensor Train Polynomial Models via Riemannian Optimization" by Novikov, Trofimov, and Oseledets for a similar approach: http://arxiv.org/abs/1605.03795

Code Overview

fixedL -- optimize a matrix product state (MPS) with a label index on the central tensor, similar to what is described in the paper arxiv:1605.05775, but where the label index stays fixed on the central tensor and does not move around during optimization. This MPS parameterizes a model whose output is a vector of 10 numbers (for the case of MNIST). The output entry with the largest value is the predicted label.

fulltest -- given an MPS ("wavefunction") generated by the fixedL program, report the classification error on the MNIST testing set.

single -- optimize an MPS for a single label type, with no label index on the MPS. This MPS parameterizes a model whose output is (ideally) positive for inputs of the correct type, and zero for all other inputs.

separate_fulltest -- report classification error for the MNIST testing set for a set of MPS created by the "single" application. IMPORTANT: this program assumes that the MPS W00, W01, W02, etc. made by running "single" reside in folders (which you have to create) named L00/, L01/, L02/ etc. So it looks for the files L00/W00, L01/W01, etc. under the folder where you run it.
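
For example, before running separate_fulltest you could create the expected folders in your run directory with mkdir L00 L01 L02 L03 L04 L05 L06 L07 L08 L09, then copy or move each W file produced by "single" into its matching folder.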

Compiling and running the programs

Dependencies:

  • ITensor (http://github.com/ITensor/ITensor) -- see below for more details
  • libpng
  • png++

Steps to install and run:

  1. Install the above dependencies.
  2. Do cp Makefile.sample Makefile to create a Makefile from the sample provided.
  3. Edit the following variables at the top of your Makefile:
    • ITENSOR_DIR: the folder where you cloned and built ITensor (the folder containing the options.mk file)
    • LIBPNG_DIR: the folder where the file libpng16.so (or libpng16.dylib on macOS) is located (change the library name here if you installed a different version of libpng)
    • PNGPP_DIR: folder where the png++ header (.hpp) files are located
  4. Run the command make, which should successfully build the fixedL application.
  5. Copy one of the sample input files from the folder sample_inputs/ to another folder of your choosing. Run each app by doing ./appname input_file_name.
  6. Edit the input file. At a minimum, change datadir to point to the location of the mllib/MNIST folder (inside this repo) on your computer. Play around with the other settings, such as Ntrain (the maximum number of training images per label), to check the basic behavior of the code before trying a heavy-duty calculation.
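
For orientation, the input files are plain text lists of name = value settings. The sketch below is illustrative only: the parameter values are made up, and it assumes the usual ITensor convention of wrapping the parameters in a named group. The files in sample_inputs/ are the authoritative reference for the exact format and the full parameter set.

```
input
{
    datadir = /path/to/TNML/mllib/MNIST
    Ntrain = 100
    Nbatch = 10
}
```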

All of the codes require you to install the ITensor tensor network library, which you can obtain from http://github.com/ITensor/ITensor. ITensor's only dependencies are a compiler supporting C++11 (language and standard library) and a BLAS/LAPACK distribution, such as the "lapack" package on Linux, the Accelerate/vecLib framework on macOS, or the Intel MKL library.

See http://itensor.org/ for help installing ITensor and for more documentation on it.

Once ITensor is installed, modify the first line of the provided Makefile to point to the ITensor installation folder. (Note: ITensor does not put files anywhere else on your computer; it just creates libraries inside its own folder.)

To use the Makefile, either just run make to build the default program (which is fixedL) or do make app=appname to compile the program appname (either fixedL, single, fulltest, or separate_fulltest).

Input files

Sample input files for fixedL and single are provided in the sample_inputs/ folder.

See below for a list of the possible input parameters to these programs and what they do.

FixedL program input parameters and code features

fixedL optimizes a matrix product state (MPS) with a label index on the central tensor, similar to what is described in the paper arxiv:1605.05775. This MPS parameterizes a model whose output is a vector of 10 numbers (for the case of MNIST). The output entry with the largest value is the predicted label.

One difference from the algorithm described in the paper is that the label index always remains on the same MPS tensor and is not moved around (although it could be moved, keeping it in a fixed position turns out to help with the optimization).

Warning: fixedL can use a lot of RAM. If this becomes a problem, adjust the Nbatch parameter described below so that the program reads smaller amounts of data into RAM at each step.
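
To make the model concrete, below is a minimal, self-contained C++ sketch of how an MPS with a label index on one tensor produces its 10-component output. It is not the TNML code (which is built on ITensor); the tensor layout, names, and bond-dimension conventions here are all illustrative. Each pixel value x in [0,1] is mapped to the local feature vector (cos(pi x/2), sin(pi x/2)) used in arxiv:1605.05775, the feature vectors are contracted with the MPS from left to right, and the largest of the 10 outputs gives the predicted label.

```cpp
// Illustrative sketch only -- not the TNML implementation.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

using Vec = std::vector<double>;

// Local feature map phi(x) = (cos(pi x/2), sin(pi x/2)) for a pixel x in [0,1].
Vec phi(double x)
{
    const double pi = std::acos(-1.0);
    return { std::cos(pi*x/2), std::sin(pi*x/2) };
}

// One MPS tensor A(s,l,a,b): physical index s (size ns), label index l
// (size nl; nl = 10 on the central tensor, nl = 1 elsewhere), and
// left/right bond indices a,b (sizes Dl, Dr).
struct Tensor
{
    int ns, nl, Dl, Dr;
    Vec data; // layout: data[((s*nl + l)*Dl + a)*Dr + b]
    double at(int s, int l, int a, int b) const
    {
        return data[std::size_t(((s*nl + l)*Dl + a)*Dr + b)];
    }
};

// Contract the MPS with the feature vectors pixel by pixel, carrying one
// boundary vector per label value encountered so far. With bond dimension 1
// at both ends, the result is the output vector f[l].
Vec evaluate(const std::vector<Tensor>& mps, const Vec& pixels)
{
    std::vector<Vec> v(1, Vec{1.0}); // single boundary vector of dim 1
    for (std::size_t j = 0; j < mps.size(); ++j)
    {
        const Tensor& T = mps[j];
        const Vec p = phi(pixels[j]);
        std::vector<Vec> w(v.size()*T.nl, Vec(T.Dr, 0.0));
        for (std::size_t lv = 0; lv < v.size(); ++lv)
            for (int l = 0; l < T.nl; ++l)
                for (int b = 0; b < T.Dr; ++b)
                    for (int a = 0; a < T.Dl; ++a)
                        for (int s = 0; s < T.ns; ++s)
                            w[lv*T.nl + l][b] += v[lv][a]*p[s]*T.at(s,l,a,b);
        v = std::move(w);
    }
    Vec f(v.size());
    for (std::size_t l = 0; l < v.size(); ++l) f[l] = v[l][0];
    return f; // 10 entries if exactly one tensor carries the label index
}

// The predicted label is the position of the largest output entry.
int predictedLabel(const Vec& f)
{
    return int(std::max_element(f.begin(), f.end()) - f.begin());
}
```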

Input parameters:

  • datadir -- location of the mllib/MNIST data folder
  • Ntrain -- maximum number of training images per label
  • Nbatch -- how many pieces to split the training data into when reading it into RAM (see the warning above)

There are other input parameters of a more experimental nature, but the ones above are the most important.

Other code features:

Tips for running fixedL:

Single program input parameters and code features

single optimizes an MPS for a single label type, with no label index on the MPS. This MPS parameterizes a model whose output is (ideally) positive for inputs of the correct type, and zero for all other inputs.

The input parameters accepted by single are mostly the same as for fixedL above.

One important extra parameter needed by single is the "label" parameter, which is an integer 0,1,...,9 telling the program which single label to "target" when optimizing the MPS.

When saving the currently optimized weight MPS to disk, the single app appends the label number that the MPS is targeting to the file name. So if the label parameter is set to 3, the program will output the file "W03" (either when the program ends or when the WRITE_WF file is found).
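
For illustration, here is a hedged sketch of the resulting one-vs-rest decision rule (not necessarily the exact logic of separate_fulltest): evaluate each label's MPS on the input and predict the label whose model gives the largest output. It reuses the Tensor type and evaluate function from the sketch in the fixedL section, with nl = 1 on every tensor so that each model returns a single number.

```cpp
// Hypothetical one-vs-rest classification; reuses Vec, Tensor, and
// evaluate() from the sketch above.
#include <limits>

// models[l] holds the MPS written by "single" for label l (e.g. loaded
// from L00/W00, ..., L09/W09). Predict the label whose model gives the
// largest output on this input.
int predictOneVsRest(const std::vector<std::vector<Tensor>>& models,
                     const Vec& pixels)
{
    int best = 0;
    double bestval = -std::numeric_limits<double>::infinity();
    for (std::size_t l = 0; l < models.size(); ++l)
    {
        const double f = evaluate(models[l], pixels)[0]; // single output
        if (f > bestval) { bestval = f; best = int(l); }
    }
    return best;
}
```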

Tips for Using the Codes