rwightman closed this issue 3 years ago.
We started an effort to create better error messages (since the JAX ones can be confusing), which includes adding asserts in various places. Your proposal fits well into that project: https://github.com/google/objax/projects/1
`isinstance` seems like a good idea. Checking `shape` also seems like a good idea, since I don't see a scenario where changing a variable's value to one with a different shape would be useful. In that case, I would expect someone to replace the variable entirely rather than replace its value. That being said, I'd love some feedback in case anyone can think of a scenario where that would make sense.

This bug just bit me when I accidentally restored a checkpoint into a model of a different size. For example, if you have a WRN-18,2 and restore a WRN-18,1 checkpoint into it, it clobbers the prior variable shapes with the new ones and you then get very weird behavior.
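One way to catch the checkpoint case before anything is clobbered is to compare shapes up front. Below is a minimal sketch. It assumes the model's variables and the checkpoint are both available as hypothetical name-to-array dicts, so it is not Objax's checkpoint API. The variable names and shapes are illustrative only.

```python
import jax.numpy as jnp


def check_restore_shapes(current, restored):
    """Compare a model's variables with a checkpoint before restoring.

    Both arguments are hypothetical dicts mapping variable names to arrays.
    Returns a list of human-readable mismatches; an empty list means every
    variable is present on both sides with the same shape.
    """
    problems = []
    for name, value in current.items():
        if name not in restored:
            problems.append(f"{name}: missing from checkpoint")
        elif restored[name].shape != value.shape:
            problems.append(
                f"{name}: model has {value.shape}, checkpoint has {restored[name].shape}"
            )
    # Variables in the checkpoint that the model does not have at all.
    for name in sorted(restored.keys() - current.keys()):
        problems.append(f"{name}: not present in model")
    return problems


# Illustrative shapes only: a wider model (more output channels) vs. a narrower checkpoint.
model_vars = {"conv1.w": jnp.zeros((3, 3, 16, 32)), "fc.b": jnp.zeros((10,))}
checkpoint = {"conv1.w": jnp.zeros((3, 3, 16, 16)), "fc.b": jnp.zeros((10,))}

for problem in check_restore_shapes(model_vars, checkpoint):
    print(problem)
```

Refusing to restore when this list is non-empty turns the "very weird behavior" into an immediate, named error.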
I'm writing a PR to address both of these bugs.
Fixed in #123.
I made a few mistakes moving weights from PyTorch with `vc.assign` and ended up clobbering all of the model's weights with the wrong shapes. There are no errors until you try to use the model.

I noticed the assign function just does `var = tensor`, with no copy option since JAX arrays are immutable. But wouldn't an `assert isinstance(tensor, JaxArray) and self.var.shape == tensor.shape` be appropriate? Or possibly an attempt to convert the type to `JaxArray` and broadcast the shape... Am I missing use cases where you'd want to change the type away from `JaxArray` or use a different shape than the original on `assign()`?
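Below is a minimal sketch of the check proposed here. It assumes a hypothetical `CheckedVar` class standing in for an Objax variable. It is not `objax.TrainVar`, and not necessarily what the fix in #123 does. It converts the incoming value with `jnp.asarray` and rejects any shape change instead of broadcasting:

```python
import jax.numpy as jnp


class CheckedVar:
    """Hypothetical stand-in for an Objax variable (not the real objax API)."""

    def __init__(self, value):
        self._value = jnp.asarray(value)

    @property
    def value(self):
        return self._value

    def assign(self, tensor):
        # Convert lists / NumPy arrays to a JAX array instead of storing them as-is.
        tensor = jnp.asarray(tensor)
        # Refuse silent shape changes: a different shape should mean creating
        # a new variable, not reassigning the value of this one.
        if tensor.shape != self._value.shape:
            raise ValueError(
                f"assign() cannot change the shape of a variable: "
                f"{self._value.shape} -> {tensor.shape}"
            )
        self._value = tensor


w = CheckedVar(jnp.zeros((3, 2)))
w.assign([[1.0, 2.0]] * 3)  # a plain list is converted to a JAX array
```

The sketch raises `ValueError` rather than using `assert` so the check still runs under `python -O`. It uses strict shape equality instead of broadcasting because a silently broadcast value would hide exactly the kind of porting mistake described above.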