bshillingford / python-torchfile

Deserialize Lua torch-serialized objects from Python
BSD 3-Clause "New" or "Revised" License
deep-learning lua machine-learning python torch

Torch serialization reader for Python


A mostly direct port of the torch7 Lua and C serialization implementation to Python, depending only on numpy (plus the standard library's array and struct modules). Sharing of objects, including torch.Tensors, is preserved.
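One way a numpy-only reader can preserve sharing between tensors is to build every tensor as a strided view over a single array that holds its storage, so tensors backed by the same Torch storage alias the same memory. The sketch below illustrates that idea and is not torchfile's code. `tensor_view` is a hypothetical helper, and it assumes Torch-style metadata: a 1-based storage offset and strides counted in elements.

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided


def tensor_view(storage, size, stride, storage_offset):
    """Return a numpy view of `storage` described by Torch-style metadata.

    Assumes `storage_offset` is 1-based (as in Lua Torch) and `stride`
    is measured in elements rather than bytes.
    """
    base = storage[storage_offset - 1:]  # a view, not a copy
    byte_strides = tuple(s * storage.itemsize for s in stride)
    return as_strided(base, shape=tuple(size), strides=byte_strides)


# One storage shared by two "tensors": a 2x3 matrix and its second row.
storage = np.arange(6, dtype=np.float64)
matrix = tensor_view(storage, size=(2, 3), stride=(3, 1), storage_offset=1)
row = tensor_view(storage, size=(3,), stride=(1,), storage_offset=4)

row[0] = 100.0        # write through one view...
print(matrix[1, 0])   # ...and the other sees it: prints 100.0
```

Because both views point into the same buffer, a write through either one is visible through the other, which is the behaviour a shared Torch storage has in Lua.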

import torchfile
stuff = torchfile.load('a_bunch_of_stuff.t7')

Installation:

Install from PyPI:

pip install torchfile

or clone this repository, then:

python setup.py install

Supports Python 2.7, 3.4, 3.5 and 3.6; other versions will probably work too.

More examples:

Write from torch, read from Python:

Lua:

+th> torch.save('/tmp/test.t7', {hello=123, world=torch.rand(1,2,3)})

Python:

In [3]: o = torchfile.load('/tmp/test.t7')
In [4]: print(o['world'].shape)
(1, 2, 3)
In [5]: o
Out[5]: 
{'hello': 123, 'world': array([[[ 0.52291083,  0.29261517,  0.11113465],
         [ 0.01017287,  0.21466237,  0.26572137]]])}
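Printing a whole loaded object, as above, quickly becomes unwieldy for larger files. A small helper such as the hypothetical `describe` below (not part of torchfile) can summarize nested dicts and lists, showing array shapes instead of contents. It assumes the loaded data is made of dicts, lists, numpy arrays and scalars, as in this example; anything else is shown with its `repr`.

```python
import numpy as np


def describe(obj, indent=0):
    """Return summary lines for nested dicts/lists, replacing array
    contents with their shape and dtype."""
    pad = "  " * indent
    lines = []
    if isinstance(obj, dict):
        for key, value in obj.items():
            lines.append("{}{!r}:".format(pad, key))
            lines.extend(describe(value, indent + 1))
    elif isinstance(obj, (list, tuple)):
        for i, value in enumerate(obj):
            lines.append("{}[{}]:".format(pad, i))
            lines.extend(describe(value, indent + 1))
    elif isinstance(obj, np.ndarray):
        lines.append("{}ndarray shape={} dtype={}".format(pad, obj.shape, obj.dtype))
    else:
        lines.append("{}{!r}".format(pad, obj))
    return lines


# Stand-in with the same structure as the object loaded above.
o = {'hello': 123, 'world': np.random.rand(1, 2, 3)}
print("\n".join(describe(o)))
```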

Arbitrary torch classes are supported:

In [1]: import torchfile

In [2]: o = torchfile.load('testfiles_x86_64/gmodule_with_linear_identity.t7')

In [3]: o.forwardnodes[3].data.module
Out[3]: TorchObject(nn.Identity, {'output': array([], dtype=float64), 'gradInput': array([], dtype=float64)})

In [4]: for node in o.forwardnodes: print(repr(node.data.module))
None
None
None
TorchObject(nn.Identity, {'output': array([], dtype=float64), 'gradInput': array([], dtype=float64)})
None
TorchObject(nn.Identity, {'output': array([], dtype=float64), 'gradInput': array([], dtype=float64)})
TorchObject(nn.Linear, {'weight': array([[-0.0248373 ],
       [ 0.17503954]]), 'gradInput': array([], dtype=float64), 'gradWeight': array([[  1.22317168e-312],
       [  1.22317168e-312]]), 'bias': array([ 0.05159848, -0.25367146]), 'gradBias': array([  1.22317168e-312,   1.22317168e-312]), 'output': array([], dtype=float64)})
TorchObject(nn.CAddTable, {'output': array([], dtype=float64), 'gradInput': []})
None

In [5]: o.forwardnodes[6].data.module.weight
Out[5]: 
array([[-0.0248373 ],
       [ 0.17503954]])

In [6]: o.forwardnodes[6].data.module.bias
Out[6]: array([ 0.05159848, -0.25367146])
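The TorchObject values above expose a module's fields as attributes (`.weight`, `.bias`) and print as `TorchObject(classname, fields)`. A minimal wrapper with that behaviour is sketched below; `SketchTorchObject` is hypothetical, is not torchfile's TorchObject, and the real class may be implemented differently.

```python
import numpy as np


class SketchTorchObject(object):
    """Hypothetical stand-in for a deserialized Torch class instance.

    Keeps the Torch class name plus the instance's fields; fields are
    readable as attributes (obj.weight) or as items (obj['weight']).
    """

    def __init__(self, typename, fields):
        self._typename = typename
        self._fields = dict(fields)

    def __getattr__(self, name):
        # Called only when normal lookup fails; going through __dict__
        # avoids infinite recursion if _fields is not set yet.
        try:
            return self.__dict__['_fields'][name]
        except KeyError:
            raise AttributeError(name)

    def __getitem__(self, key):
        return self._fields[key]

    def __repr__(self):
        return 'SketchTorchObject({}, {!r})'.format(self._typename, self._fields)


linear = SketchTorchObject('nn.Linear', {
    'weight': np.array([[-0.0248373], [0.17503954]]),
    'bias': np.array([0.05159848, -0.25367146]),
})
print(linear.bias)   # same array as linear['bias']
```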

More complex writing from torch:

Lua:

+th> f = torch.DiskFile('/tmp/test.t7', 'w'):binary()
+th> f:writeBool(false)
+th> f:writeObject({hello=123})
+th> f:writeInt(456)
+th> f:close()

Python:

In [1]: import torchfile
In [2]: with open('/tmp/test.t7','rb') as f:
   ...:     r = torchfile.T7Reader(f)
   ...:     print(r.read_boolean())
   ...:     print(r.read_obj())
   ...:     print(r.read_int())
   ...: 
False
{'hello': 123}
456
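The Lua session above writes a boolean, a table and an integer back to back, and T7Reader reads them in the same order. A simplified reader for such a stream is sketched below. It is hypothetical (`MiniReader` is not torchfile's T7Reader) and assumes: little-endian data, 4-byte ints, 8-byte doubles, booleans stored as ints (1 for true), and the type tags 0 = nil, 1 = number, 2 = string, 3 = table, 5 = boolean. It also shows one way sharing can be preserved: each table is written with an index, and a repeated index returns the object already built. Converting integral numbers to `int` mirrors the `123` in the output above but is likewise an assumption.

```python
import io
import struct

# Assumed type tags (a subset of Torch's binary serialization format).
TYPE_NIL, TYPE_NUMBER, TYPE_STRING, TYPE_TABLE, TYPE_BOOLEAN = 0, 1, 2, 3, 5


class MiniReader(object):
    """Hypothetical, simplified reader: little-endian, 4-byte ints."""

    def __init__(self, f):
        self.f = f
        self.objects = {}  # table index -> already-built Python object

    def _unpack(self, fmt, size):
        return struct.unpack(fmt, self.f.read(size))[0]

    def read_int(self):
        return self._unpack('<i', 4)

    def read_double(self):
        return self._unpack('<d', 8)

    def read_boolean(self):
        return self.read_int() == 1

    def read_string(self):
        n = self.read_int()
        # Decoded to str for readability (assumption: Lua strings are raw bytes).
        return self.f.read(n).decode('utf-8')

    def read_obj(self):
        typeidx = self.read_int()
        if typeidx == TYPE_NIL:
            return None
        if typeidx == TYPE_NUMBER:
            x = self.read_double()
            return int(x) if x.is_integer() else x
        if typeidx == TYPE_STRING:
            return self.read_string()
        if typeidx == TYPE_BOOLEAN:
            return self.read_boolean()
        if typeidx == TYPE_TABLE:
            index = self.read_int()
            if index in self.objects:      # seen before: reuse the same object
                return self.objects[index]
            table = {}
            self.objects[index] = table    # register before reading contents
            for _ in range(self.read_int()):
                key = self.read_obj()
                table[key] = self.read_obj()
            return table
        raise ValueError('unsupported type tag %d' % typeidx)


# Bytes matching the Lua session above, under the same assumptions.
data = (struct.pack('<i', 0)                       # writeBool(false)
        + struct.pack('<iii', TYPE_TABLE, 1, 1)    # table #1 with 1 pair
        + struct.pack('<ii', TYPE_STRING, 5) + b'hello'
        + struct.pack('<id', TYPE_NUMBER, 123.0)
        + struct.pack('<i', 456))                  # writeInt(456)

r = MiniReader(io.BytesIO(data))
print(r.read_boolean())  # False
print(r.read_obj())      # {'hello': 123}
print(r.read_int())      # 456
```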

Supported types:

(*) Since Lua allows you to index a table with a table but Python does not, we replace dicts with a subclass that is hashable, and change its equality comparison behaviour to compare by reference. See hashable_uniq_dict.
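The idea behind the (*) note can be sketched as a dict subclass that hashes and compares by identity rather than by contents. `RefDict` is a hypothetical name used only for illustration; torchfile's own class is hashable_uniq_dict, whose details may differ.

```python
class RefDict(dict):
    """Hypothetical dict subclass usable as a dict key.

    Lua tables can be keys of other tables, and two distinct tables stay
    distinct keys even when their contents are equal, so hash and compare
    by object identity instead of by contents.
    """

    def __hash__(self):
        return id(self)

    def __eq__(self, other):
        return self is other

    def __ne__(self, other):  # needed explicitly on Python 2
        return self is not other


a = RefDict(x=1)
b = RefDict(x=1)
outer = {a: 'first', b: 'second'}
print(len(outer))   # 2: equal contents, but different objects
print(outer[a])     # first
```

The trade-off is that two tables with identical contents never compare equal; comparing contents requires converting them to plain dicts first, e.g. `dict(a) == dict(b)`.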

Test files demonstrating various features:

In [1]: import torchfile

In [2]: torchfile.load('testfiles_x86_64/list_table.t7')
Out[2]: ['hello', 'world', 'third item', 123]

In [3]: torchfile.load('testfiles_x86_64/doubletensor.t7')
Out[3]: 
array([[ 1. ,  2. ,  3. ],
       [ 4. ,  5. ,  6.9]])

# ...also other files demonstrating various types.
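The first output shows a Lua table coming back as a Python list: Lua has no separate array type, so an array-like table arrives as a mapping from 1-based integer keys to values. A plausible conversion rule is sketched below, assuming only tables whose keys are exactly 1..n are turned into lists. `table_to_python` is hypothetical, and torchfile's actual heuristic may differ.

```python
def table_to_python(table):
    """Return a list if the table's keys are exactly 1..n, else the table.

    Float keys such as 1.0 also qualify, since Lua numbers are doubles and
    1.0 == 1 in Python. Empty tables are left as dicts.
    """
    n = len(table)
    numeric = all(isinstance(k, (int, float)) and not isinstance(k, bool)
                  for k in table)
    if n > 0 and numeric and set(table) == set(range(1, n + 1)):
        # Array-like: order the values by their 1-based key.
        return [table[i] for i in range(1, n + 1)]
    return table


print(table_to_python({1: 'hello', 2: 'world', 3: 'third item', 4: 123}))
# ['hello', 'world', 'third item', 123]
print(table_to_python({1: 'a', 3: 'c'}))   # has a gap, stays a dict
```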

The example .t7 files will load on any modern Intel or AMD 64-bit CPU, but note that the code uses the host's native byte ordering and related conventions. The implementation currently assumes Torch's system-dependent binary format; supporting the ASCII format as well would require only minor refactoring.
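The difference between native and explicit byte ordering can be seen with Python's struct module: a format prefixed with `=` uses the host's byte order, while `<` always decodes little-endian data, whatever the host. `read_int32` is a hypothetical helper for illustration, not part of torchfile.

```python
import struct
import sys


def read_int32(buf, byteorder='native'):
    """Decode a 4-byte signed int using the given byte order."""
    prefix = {'native': '=', 'little': '<', 'big': '>'}[byteorder]
    return struct.unpack(prefix + 'i', buf)[0]


raw = struct.pack('<i', 456)      # bytes as written on a little-endian machine
print(sys.byteorder)              # 'little' on x86-64
print(read_int32(raw, 'little'))  # 456 on any host
print(read_int32(raw))            # 456 only if the host is little-endian
```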