google / objax

Restoring variable into module with extra layer fails incorrectly #146

Closed carlini closed 3 years ago

carlini commented 3 years ago

The following test case creates two modules that are identical, except that the second has an extra no-op layer. Saving the first module's VarCollection and restoring it into the second crashes (I would expect this to work in the first place), and the error message gives a really strange reason.

Code to reproduce

import objax


def test():
    nn = objax.nn.Sequential([objax.nn.Linear(1, 2),
                              objax.nn.Linear(2, 3),
                              objax.nn.Linear(3, 4),
    ])

    class NOP(objax.Module):
        def __call__(self, x):
            return x

    nn2 = objax.nn.Sequential([objax.nn.Linear(1, 2),
                               NOP(),
                               objax.nn.Linear(2, 3),
                               objax.nn.Linear(3, 4),
    ])

    ckpt = objax.io.Checkpoint("/tmp/test_ckpt", 1, makedir=True)

    print("Printing nn")
    print(nn.vars())
    print("Printing nn2")
    print(nn2.vars())

    ckpt.save(nn.vars(), 0)
    ckpt.restore(nn.vars(), 0)
    print("Ok nn")

    ckpt.save(nn2.vars(), 0)
    ckpt.restore(nn2.vars(), 0)
    print("Ok nn2")

    ckpt.save(nn.vars(), 0)
    ckpt.restore(nn2.vars(), 0)
    print("Ok cross")

Output:

Printing nn
(Sequential)[0](Linear).b        2 (2,)
(Sequential)[0](Linear).w        2 (1, 2)
(Sequential)[1](Linear).b        3 (3,)
(Sequential)[1](Linear).w        6 (2, 3)
(Sequential)[2](Linear).b        4 (4,)
(Sequential)[2](Linear).w       12 (3, 4)
+Total(6)                       29
Printing nn2
(Sequential)[0](Linear).b        2 (2,)
(Sequential)[0](Linear).w        2 (1, 2)
(Sequential)[1](Linear).b        3 (3,)
(Sequential)[1](Linear).w        6 (2, 3)
(Sequential)[3](Linear).b        4 (4,)
(Sequential)[3](Linear).w       12 (3, 4)
+Total(6)                       29
Resuming from /tmp/test_ckpt/ckpt/0000000000.npz
Ok nn
Resuming from /tmp/test_ckpt/ckpt/0000000000.npz
Ok nn2
Resuming from /tmp/test_ckpt/ckpt/0000000000.npz
Traceback (most recent call last):
  File "/home/ncarlini/diagnosing-failures/attack.py", line 146, in test
    ckpt.restore(nn2.vars(), 0)
  File "../backdoor/objax/objax/io/checkpoint.py", line 102, in restore
    self.LOAD_FN(ckpt, vc)
  File "../backdoor/objax/objax/io/ops.py", line 66, in load_var_collection
    raise AssertionError(f'Error when restoring variable {name}: {str(e)}') from None
AssertionError: Error when restoring variable (Sequential)[2](Linear).b: Assign can not change shape of variable. The current variable shape is (3,), but the requested new shape is (4,).

But wait, there's more. If you change the nn2 definition to

    nn2 = objax.nn.Sequential([objax.nn.Linear(1, 2),
                               objax.nn.Linear(2, 3),
                               NOP(),
                               objax.nn.Linear(3, 4),
    ])

then the error is reasonable, and the message says something (slightly) more sane:

  File "../backdoor/objax/objax/io/ops.py", line 72, in load_var_collection
    raise ValueError(f'Missing value for variables currently in the model: {misses}. '
ValueError: Missing value for variables currently in the model: ['(Sequential)[3](Linear).b', '(Sequential)[3](Linear).w']. The following variables on disk were not used, maybe the missing variable was renamed from one of these: {'(Sequential)[2](Linear).w', '(Sequential)[2](Linear).b'}.

Since this is a VarCollection, I might have expected everything to just work. But I can get over this one crashing, given how names are assigned.
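To make the two failure modes above concrete, here is a minimal plain-Python sketch of a name-keyed restore. It does not use objax and is not its actual implementation: `restore_by_name`, `linear_vars`, and `outcome` are hypothetical helpers, and variables are reduced to name-to-shape entries. It only assumes that restoring matches variables by name, as the error messages suggest.

```python
def linear_vars(layout):
    """Map {position: (nin, nout)} to a name -> shape dict (hypothetical helper)."""
    out = {}
    for index, (nin, nout) in layout.items():
        out[f"(Sequential)[{index}](Linear).b"] = (nout,)
        out[f"(Sequential)[{index}](Linear).w"] = (nin, nout)
    return out


def restore_by_name(saved, current):
    """Toy name-keyed restore; an assumption, not objax's real code."""
    misses, used = [], set()
    for name, shape in current.items():
        if name not in saved:
            # No saved value under this name: report it after the loop.
            misses.append(name)
        elif saved[name] != shape:
            # Same name, different shape: fail immediately.
            raise AssertionError(
                f"Error when restoring variable {name}: "
                f"current shape {shape}, saved shape {saved[name]}"
            )
        else:
            used.add(name)
    if misses:
        unused = sorted(set(saved) - used)
        raise ValueError(f"Missing: {misses}. Unused on disk: {unused}")


def outcome(saved, current):
    """Return 'ok' or the name of the exception raised by the toy restore."""
    try:
        restore_by_name(saved, current)
        return "ok"
    except (AssertionError, ValueError) as e:
        return type(e).__name__


saved = linear_vars({0: (1, 2), 1: (2, 3), 2: (3, 4)})     # nn
nop_at_1 = linear_vars({0: (1, 2), 2: (2, 3), 3: (3, 4)})  # NOP in slot 1
nop_at_2 = linear_vars({0: (1, 2), 1: (2, 3), 3: (3, 4)})  # NOP in slot 2

results = {
    "same": outcome(saved, saved),
    "nop_at_1": outcome(saved, nop_at_1),
    "nop_at_2": outcome(saved, nop_at_2),
}
print(results)
```

With the NOP in slot 1, the name `(Sequential)[2](Linear).b` exists on both sides but with different shapes, so the toy restore fails on the shape check. With the NOP in slot 2, the `[3]` names have no saved counterpart, so it reports missing and unused names instead.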

david-berthelot commented 3 years ago

I don't get the same output when running your sample code on 1.2.0:

Printing nn
(Sequential)[0](Linear).b        2 (2,)
(Sequential)[0](Linear).w        2 (1, 2)
(Sequential)[1](Linear).b        3 (3,)
(Sequential)[1](Linear).w        6 (2, 3)
(Sequential)[2](Linear).b        4 (4,)
(Sequential)[2](Linear).w       12 (3, 4)
+Total(6)                       29
Printing nn2
(Sequential)[0](Linear).b        2 (2,)
(Sequential)[0](Linear).w        2 (1, 2)
(Sequential)[2](Linear).b        3 (3,)
(Sequential)[2](Linear).w        6 (2, 3)
(Sequential)[3](Linear).b        4 (4,)
(Sequential)[3](Linear).w       12 (3, 4)
+Total(6)                       29
Resuming from /tmp/test_ckpt/ckpt/0000000000.npz
Ok nn
Resuming from /tmp/test_ckpt/ckpt/0000000000.npz
Ok nn2
Resuming from /tmp/test_ckpt/ckpt/0000000000.npz
Traceback (most recent call last):
  File "/home/dberth/jax3/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 3343, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-2-96ab3b1a310e>", line 1, in <module>
    runfile('/home/dberth/Code/objax/delme.py', wdir='/home/dberth/Code/objax')
  File "/home/dberth/.local/share/JetBrains/IntelliJIdea2020.2/python/helpers/pydev/_pydev_bundle/pydev_umd.py", line 197, in runfile
    pydev_imports.execfile(filename, global_vars, local_vars)  # execute the script
  File "/home/dberth/.local/share/JetBrains/IntelliJIdea2020.2/python/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/home/dberth/Code/objax/delme.py", line 36, in <module>
    ckpt.restore(nn2.vars(), 0)
  File "/home/dberth/Code/objax/objax/io/checkpoint.py", line 102, in restore
    self.LOAD_FN(ckpt, vc)
  File "/home/dberth/Code/objax/objax/io/ops.py", line 66, in load_var_collection
    raise AssertionError(f'Error when restoring variable {name}: {str(e)}') from None
AssertionError: Error when restoring variable (Sequential)[2](Linear).b: Assign can not change shape of variable. The current variable shape is (3,), but the requested new shape is (4,).

Namely, the nn2 output is different: your output does not have [2], while mine does. That basically explains the message. Although the variables have the same names, the change in architecture means they are not the same variables: [2] in nn2 is really [1] in nn.
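The index shift can be sketched without objax. The snippet below is only an illustration under the assumption that names are built from each layer's position in the Sequential list; `var_names` and `linear` are hypothetical helpers, not part of the library.

```python
def linear(nin, nout):
    """Describe a Linear layer as (kind, {param name: shape})."""
    return ("Linear", {"b": (nout,), "w": (nin, nout)})


# A layer with no variables still occupies a slot in the list.
NOP = ("NOP", {})


def var_names(layers):
    """Build a name -> shape dict using each layer's position in the list."""
    names = {}
    for index, (kind, params) in enumerate(layers):
        for pname, shape in params.items():
            names[f"(Sequential)[{index}]({kind}).{pname}"] = shape
    return names


nn = var_names([linear(1, 2), linear(2, 3), linear(3, 4)])
nn2 = var_names([linear(1, 2), NOP, linear(2, 3), linear(3, 4)])

# Names present in both collections, and those whose shapes disagree.
shared = sorted(set(nn) & set(nn2))
clashes = [name for name in shared if nn[name] != nn2[name]]
print(clashes)
```

Both collections contain a `(Sequential)[2](Linear).b`, but in nn it belongs to the 3-to-4 layer and in nn2 to the 2-to-3 layer, which is where the (3,) versus (4,) shape mismatch comes from.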

So the error seems to make sense to me, or am I missing something?

carlini commented 3 years ago

Er. Yes. I think this is actually all correct... I was reading the layers off by one.