carlini closed this 3 years ago
I don't get the same output when running your sample code on 1.2.0
Printing nn
(Sequential)[0](Linear).b 2 (2,)
(Sequential)[0](Linear).w 2 (1, 2)
(Sequential)[1](Linear).b 3 (3,)
(Sequential)[1](Linear).w 6 (2, 3)
(Sequential)[2](Linear).b 4 (4,)
(Sequential)[2](Linear).w 12 (3, 4)
+Total(6) 29
Printing nn2
(Sequential)[0](Linear).b 2 (2,)
(Sequential)[0](Linear).w 2 (1, 2)
(Sequential)[2](Linear).b 3 (3,)
(Sequential)[2](Linear).w 6 (2, 3)
(Sequential)[3](Linear).b 4 (4,)
(Sequential)[3](Linear).w 12 (3, 4)
+Total(6) 29
Resuming from /tmp/test_ckpt/ckpt/0000000000.npz
Ok nn
Resuming from /tmp/test_ckpt/ckpt/0000000000.npz
Ok nn2
Resuming from /tmp/test_ckpt/ckpt/0000000000.npz
Traceback (most recent call last):
  File "/home/dberth/jax3/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 3343, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-2-96ab3b1a310e>", line 1, in <module>
    runfile('/home/dberth/Code/objax/delme.py', wdir='/home/dberth/Code/objax')
  File "/home/dberth/.local/share/JetBrains/IntelliJIdea2020.2/python/helpers/pydev/_pydev_bundle/pydev_umd.py", line 197, in runfile
    pydev_imports.execfile(filename, global_vars, local_vars) # execute the script
  File "/home/dberth/.local/share/JetBrains/IntelliJIdea2020.2/python/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/home/dberth/Code/objax/delme.py", line 36, in <module>
    ckpt.restore(nn2.vars(), 0)
  File "/home/dberth/Code/objax/objax/io/checkpoint.py", line 102, in restore
    self.LOAD_FN(ckpt, vc)
  File "/home/dberth/Code/objax/objax/io/ops.py", line 66, in load_var_collection
    raise AssertionError(f'Error when restoring variable {name}: {str(e)}') from None
AssertionError: Error when restoring variable (Sequential)[2](Linear).b: Assign can not change shape of variable. The current variable shape is (3,), but the requested new shape is (4,).
Namely, the `nn2` output is different: your output does not have `[2]` while mine does.
That basically explains the message: although the variables have the same names, the change in architecture means they are not the same variables; `[2]` in `nn2` is really `[1]` in `nn`.
So it seems to make sense to me, or am I missing something?
Er. Yes. I think this is actually all correct... I was reading the layers off by one.
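The off-by-one naming discussed above can be sketched without Objax at all. The snippet below is a minimal illustration, not the library's implementation: `var_shapes` and `restore` are hypothetical helpers, a layer is modeled as an `(n_in, n_out)` tuple, and `None` stands in for a parameter-free no-op layer. It only assumes that variables are named by their layer's position and restored by name, as the listing and traceback suggest.

```python
# Hypothetical stand-ins, NOT the Objax API: a layer is an (n_in, n_out)
# tuple for a linear layer, or None for a parameter-free no-op layer.

def var_shapes(layers):
    """Name each variable by its layer's position, mimicking the listing above."""
    shapes = {}
    for i, layer in enumerate(layers):
        if layer is None:
            continue  # no variables, but the position index is still consumed
        n_in, n_out = layer
        shapes[f"(Sequential)[{i}](Linear).b"] = (n_out,)
        shapes[f"(Sequential)[{i}](Linear).w"] = (n_in, n_out)
    return shapes


def restore(saved, target):
    """Restore by name into `target`, refusing to change a variable's shape."""
    for name, current in target.items():
        if name not in saved:
            raise KeyError(f"Variable {name} is missing from the checkpoint")
        requested = saved[name]
        if current != requested:
            # Message wording only imitates the one in the traceback.
            raise AssertionError(
                f"Error when restoring variable {name}: current shape is "
                f"{current}, but the requested new shape is {requested}"
            )
        target[name] = requested


nn = [(1, 2), (2, 3), (3, 4)]          # three linear layers
nn2 = [(1, 2), None, (2, 3), (3, 4)]   # same layers plus a no-op at index 1

saved = var_shapes(nn)
restore(saved, var_shapes(nn))  # same architecture: names and shapes agree

try:
    restore(saved, var_shapes(nn2))
except AssertionError as err:
    print(err)  # the name [2] exists in both, but refers to different layers
```

In this sketch the name `(Sequential)[2](Linear).b` exists in both collections, so the lookup succeeds and the failure only surfaces as a shape mismatch, which is why the message looks strange at first sight.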
The following test case creates two modules that are identical, except that the second module has a no-op layer. Saving the first varcollection and restoring it into the second crashes (it should work in the first place), and it also gives a really strange reason.
Code to reproduce
Output:
But wait, there's more. If you change the `nn2` definition to
then the error is okay, and says something (slightly) more sane
even though I might expect that, since this is a `varcollection`, everything would work. But I can get over this one crashing because of how names are assigned.