mfinzi / equivariant-MLP

A library for programmatically generating equivariant layers through constraint solving
MIT License
251 stars 21 forks source link

Saving and Loading Objax EMLPs yields slightly different predictions #8

Open mfinzi opened 3 years ago

mfinzi commented 3 years ago

It appears that loading an EMLP model saved with objax.io.save_var_collection yields slightly different predictions than the original model.

import emlp
from emlp.groups import SO
from emlp.reps import T, V
import numpy as np
import objax

# Build an SO(3)-equivariant MLP taking T(1)+T(2) features to a scalar T(0).
model = emlp.nn.EMLP(T(1) + T(2), T(0), SO(3))
inputs = np.linspace(0, 1, 12)
print(model(inputs).T)

# Saving with file descriptor
with open('net.npz', 'wb') as f:
    objax.io.save_var_collection(f, model.vars())

Output: [0.4905533]

import emlp
from emlp.groups import SO  # fixed: was `S`, but SO(3) below requires SO
from emlp.reps import T, V
import numpy as np
import objax

# Rebuild the same architecture, then restore the saved weights into it.
net = emlp.nn.EMLP(T(1) + T(2), T(0), SO(3))
x = np.linspace(0, 1, 12)
# Loading with file descriptor
with open('net.npz', 'rb') as f:
    objax.io.load_var_collection(f, net.vars())

print(net(x).T)

Output: [0.4904544]

mfinzi commented 3 years ago

It looks like this is due to randomness in the bilinear layer that is not captured as objax state variables. A workaround is to use the same numpy seed when initializing the model in both cases:

import emlp
from emlp.groups import SO
from emlp.reps import T, V
import numpy as np
import objax

# Fix the numpy seed so all randomness used at construction is reproducible.
np.random.seed(42)
model = emlp.nn.EMLP(T(1) + T(2), T(0), SO(3))
inputs = np.linspace(0, 1, 12)
print(model(inputs).T)

# Saving with file descriptor
with open('net.npz', 'wb') as f:
    objax.io.save_var_collection(f, model.vars())

Output: [0.03251831]

import emlp
from emlp.groups import SO  # fixed: was `S`, but SO(3) below requires SO
from emlp.reps import T, V
import numpy as np
import objax

# Same seed as when saving, so the unsaved randomness in the bilinear layer
# is reproduced identically before the saved variables are loaded.
np.random.seed(42)
net = emlp.nn.EMLP(T(1) + T(2), T(0), SO(3))
x = np.linspace(0, 1, 12)
# Loading with file descriptor
with open('net.npz', 'rb') as f:
    objax.io.load_var_collection(f, net.vars())

print(net(x).T)

Output: [0.03251831]