google / objax

Add type and shape check to var assign #106

Closed rwightman closed 3 years ago

rwightman commented 4 years ago

I made a few mistakes moving weights from PyTorch with vc.assign and ended up overwriting all of the model's weights with tensors of the wrong shape. No error is raised until you try to use the model.

I noticed that the assign fn just does var = tensor ... with no copy option, since JAX arrays are immutable. Wouldn't an assert isinstance(tensor, JaxArray) and self.var.shape == tensor.shape be appropriate? Or possibly an attempt to convert the input to a JaxArray and broadcast it to the existing shape...

Am I missing use cases where you'd want to change type away from JaxArray or use a different shape than the original on assign()?
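
A minimal sketch of the two options raised above: a strict type and shape check, and a lenient convert-and-broadcast variant. It uses NumPy arrays as a stand-in for JAX arrays so it runs without JAX installed. `ToyVar` and its `broadcast` flag are hypothetical names for illustration and are not objax's actual API.

```python
import numpy as np


class ToyVar:
    """Hypothetical stand-in for a framework variable; not objax's real class."""

    def __init__(self, value: np.ndarray):
        self._value = value

    @property
    def value(self) -> np.ndarray:
        return self._value

    def assign(self, tensor, broadcast: bool = False) -> None:
        """Replace the stored value, refusing inputs that would change its shape."""
        if broadcast:
            # Lenient variant: convert to an array and broadcast to the existing
            # shape. np.broadcast_to raises ValueError on incompatible shapes.
            tensor = np.broadcast_to(
                np.asarray(tensor, dtype=self._value.dtype), self._value.shape
            ).copy()
        if not isinstance(tensor, np.ndarray):
            raise TypeError(f"expected an array, got {type(tensor).__name__}")
        if tensor.shape != self._value.shape:
            raise ValueError(
                f"shape mismatch: variable has {self._value.shape}, got {tensor.shape}"
            )
        self._value = tensor


w = ToyVar(np.zeros((3, 4)))
w.assign(np.ones((3, 4)))        # same shape: accepted
w.assign(2.0, broadcast=True)    # scalar is broadcast to (3, 4)
try:
    w.assign(np.ones((4, 3)))    # e.g. a transposed weight matrix: rejected
except ValueError as err:
    print(err)
```

With the strict path, a wrongly shaped tensor fails at the assignment itself rather than later, when the model is first used.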

david-berthelot commented 4 years ago

We started an effort to create better error messages (since the JAX ones can lead to confusion), which includes adding asserts in various places. Your proposal fits well into that project: https://github.com/google/objax/projects/1

  1. Assert isinstance seems like a good idea.
  2. Assert shape also seems like a good idea since I don't see under what scenario modifying a variable value to one with a different shape would be useful. In that case, I would expect someone to replace a variable entirely rather than replacing its value. That being said, I'd love some feedback in case anyone can think of a scenario where that would make sense.

carlini commented 4 years ago

This bug just bit me when I accidentally restored a checkpoint from a differently sized model. For example, if you have a WRN-18,2 and restore a WRN-18,1 checkpoint, it'll overwrite the existing shapes with the new ones, and then you get very weird behavior.

I'm writing a PR to address both of these bugs.
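
For the checkpoint case described above, one possible approach is to validate every saved shape before assigning anything, so a mismatch is reported up front instead of silently resizing the model. This is only a sketch under assumptions: `check_restore` and the variable names are hypothetical, plain dicts of NumPy arrays stand in for a variable collection and a checkpoint, and this is not necessarily how the eventual fix was implemented.

```python
import numpy as np


def check_restore(current: dict, saved: dict) -> None:
    """Raise ValueError listing every variable that is missing or has a different shape."""
    problems = []
    for name, value in current.items():
        if name not in saved:
            problems.append(f"{name}: missing from checkpoint")
        elif saved[name].shape != value.shape:
            problems.append(
                f"{name}: model {value.shape} vs checkpoint {saved[name].shape}"
            )
    if problems:
        raise ValueError("cannot restore checkpoint:\n  " + "\n  ".join(problems))


# Two toy "models" that differ only in width, loosely mirroring the WRN example.
wide = {"conv.w": np.zeros((3, 3, 16, 32)), "conv.b": np.zeros((32,))}
narrow = {"conv.w": np.zeros((3, 3, 16, 16)), "conv.b": np.zeros((16,))}

check_restore(wide, wide)          # identical shapes: passes silently
try:
    check_restore(wide, narrow)    # width mismatch: reported before any assignment
except ValueError as err:
    print(err)
```

Collecting all mismatches before raising gives one error that names every offending variable, which is easier to act on than failing at the first one.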

carlini commented 3 years ago

Fixed in #123.