bshillingford / python-torchfile

Deserialize Lua torch-serialized objects from Python
BSD 3-Clause "New" or "Revised" License

Extracting values from a torch (.t7) object #8

Closed by AakashKumarNain 7 years ago

AakashKumarNain commented 7 years ago

I was trying to convert a torch object to a NumPy array, more precisely to a .npy file. After loading the file with t = torchfile.load('weights_file.t7'), I got a torchfile object. The attached screenshots show what the object looks like. How can I extract the arrays from it?

[Screenshot from 2017-05-13 15:02:44]

[Screenshot from 2017-05-13 15:03:05]

bshillingford commented 7 years ago

You can see its properties with .keys() or tab completion, as with any other dictionary. A TorchObject is a Lua table emulated in Python (so that all objects, including non-hashable ones, can be keys). If in Lua you would access the property containing your torch tensor with obj['a']['b'][123], you do the same in Python to get the NumPy array.
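A minimal sketch of that exploration, printing the whole key tree instead of typing .keys() at every level. It assumes, following the comment above, that loaded tables answer .keys() and obj[key], that array-like tables may come back as Python lists, and that tensors arrive as NumPy arrays. The helper describe and the toy net dictionary are hypothetical; the toy only imitates the shape of a loaded network, since no real .t7 file is read here.

```python
import numpy as np


def describe(obj, name="root", depth=0, max_depth=4):
    """Return an indented outline of a loaded object's keys, types and shapes."""
    pad = "  " * depth
    if isinstance(obj, np.ndarray):
        # Tensors are assumed to arrive as NumPy arrays: report shape and dtype.
        return [f"{pad}{name}: ndarray{obj.shape} {obj.dtype}"]
    if depth >= max_depth:
        return [f"{pad}{name}: {type(obj).__name__} (not expanded)"]
    if hasattr(obj, "keys"):
        # Dict-like table: recurse with obj[key], the same indexing you would type by hand.
        lines = [f"{pad}{name}: {type(obj).__name__}"]
        for key in obj.keys():
            # Keys may be bytes; decode them for display only.
            label = key.decode("utf8") if isinstance(key, bytes) else str(key)
            lines += describe(obj[key], label, depth + 1, max_depth)
        return lines
    if isinstance(obj, (list, tuple)):
        # Array-like table returned as a Python sequence: recurse by position.
        lines = [f"{pad}{name}: {type(obj).__name__}[{len(obj)}]"]
        for i, item in enumerate(obj):
            lines += describe(item, f"[{i}]", depth + 1, max_depth)
        return lines
    # Scalars, strings, booleans and anything else: show the value itself.
    return [f"{pad}{name}: {obj!r}"]


if __name__ == "__main__":
    # Hypothetical stand-in for torchfile.load('weights_file.t7').
    net = {b"modules": [{b"weight": np.zeros((4, 3), np.float32),
                         b"bias": np.zeros(4, np.float32)}]}
    print("\n".join(describe(net)))
```

Once the outline shows where an array lives, the same chain of keys or indices can be written out by hand to fetch it. The max_depth cap keeps the output readable for deep networks.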

The file you sent does not contain a single torch Tensor; it is an nn.Module containing an entire network. Loop over its modules (either from Python or from Lua) to get the torch tensors or NumPy arrays you want.
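Looping over the modules from Python can be sketched as a recursive walk that follows each container's modules table and keeps any weight and bias arrays it finds. The field names weight, bias and modules follow Lua nn conventions; the helpers field and collect_params, the dotted key format, and the toy net are hypothetical. The sketch assumes tensors were already converted to NumPy arrays and that a table may use either str or bytes keys, or attribute access.

```python
import os
import tempfile

import numpy as np


def field(obj, name):
    """Read `name` from a loaded table: str key, bytes key, or attribute."""
    if isinstance(obj, dict):
        return obj.get(name, obj.get(name.encode("utf8")))
    return getattr(obj, name, None)


def collect_params(module, prefix="", names=("weight", "bias")):
    """Recursively gather parameter arrays from a module and its .modules."""
    params = {}
    for name in names:
        value = field(module, name)
        if isinstance(value, np.ndarray):
            params[prefix + name] = value
    children = field(module, "modules")
    if isinstance(children, dict):  # table kept as a dict with integer keys
        items = children.items()
    elif isinstance(children, (list, tuple)):  # table converted to a list
        items = enumerate(children)
    else:  # leaf layer: no submodules
        items = []
    for index, child in items:
        params.update(collect_params(child, f"{prefix}{index}.", names))
    return params


if __name__ == "__main__":
    # Hypothetical toy network shaped like a loaded nn.Sequential.
    net = {"modules": [{"weight": np.ones((2, 3)), "bias": np.zeros(2)},
                       {"train": True}]}
    params = collect_params(net)
    with tempfile.TemporaryDirectory() as tmp:
        # All arrays go into one archive, each stored under its dotted key.
        np.savez(os.path.join(tmp, "weights.npz"), **params)
    print(sorted(params))
```

To get one .npy file per array instead of a single .npz archive, call np.save(f"{key}.npy", array) for each item of the returned dict.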

ShangxuanWu commented 7 years ago

You can use something like np.array(cpt.modules[41].bias), where cpt is the loaded object.
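To finish the original goal of a .npy file, the extracted array can be written with np.save and read back with np.load. A minimal sketch; the helper export_npy is hypothetical, and the stand-in array replaces cpt.modules[41].bias because the index 41 is specific to that particular network.

```python
import os
import tempfile

import numpy as np


def export_npy(array_like, path):
    """Copy `array_like` into a standalone ndarray, save it to `path`, and return it."""
    array = np.array(array_like)  # np.array copies by default
    np.save(path, array)  # a path already ending in .npy is used as-is
    return array


if __name__ == "__main__":
    # Hypothetical stand-in for cpt.modules[41].bias from a loaded checkpoint.
    bias = np.arange(4, dtype=np.float32)
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "bias.npy")
        export_npy(bias, path)
        restored = np.load(path)
        print(restored.dtype, restored.shape)
```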