Alescontrela / viper_rl

Using advances in generative modeling to learn reward functions from unlabeled videos.
MIT License
111 stars 12 forks

Weird bug for BroadcastPositionBiases #4

Open alfaevc opened 9 months ago

alfaevc commented 9 months ago

It seems there is a dimensionality issue when I run train_videogpt.py:

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/gpfs/data/oermannlab/users/qp2040/viper_rl/scripts/train_videogpt.py", line 272, in <module>
    main()
  File "/gpfs/data/oermannlab/users/qp2040/viper_rl/scripts/train_videogpt.py", line 79, in main
    visualize(sampler, ae, iteration, state, test_loader)
  File "/gpfs/data/oermannlab/users/qp2040/viper_rl/scripts/train_videogpt.py", line 225, in visualize
    samples = sampler(variables, batch).copy()
  File "/gpfs/data/oermannlab/users/qp2040/viper_rl/viper_rl/videogpt/sampler.py", line 90, in __call__
    _, cache = self._model_step(
  File "/gpfs/data/oermannlab/users/qp2040/viper_rl/viper_rl/videogpt/sampler.py", line 38, in _fn
    logits, cache = self.model.apply(
  File "/gpfs/data/oermannlab/users/qp2040/viper_rl/viper_rl/videogpt/models/videogpt.py", line 38, in __call__
    return self.model(
  File "/gpfs/data/oermannlab/users/qp2040/viper_rl/viper_rl/videogpt/models/transformer.py", line 31, in __call__
    position_bias = BroadcastPositionBiases(shape=self.shape)(x)
  File "/gpfs/data/oermannlab/users/qp2040/viper_rl/viper_rl/videogpt/models/transformer.py", line 188, in __call__
    embs = [
  File "/gpfs/data/oermannlab/users/qp2040/viper_rl/viper_rl/videogpt/models/transformer.py", line 189, in <listcomp>
    self.param(f'd_{i}', nn.initializers.normal(stddev=0.02),
flax.errors.ScopeParamShapeError: Initializer expected to generate shape (16, 85) but got shape (8, 85) instead for parameter "d_1" in "/model/BroadcastPositionBiases_0". (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ScopeParamShapeError)
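In Flax, a ScopeParamShapeError means the parameter already exists in the variables passed to `apply`, but the current call asks the initializer for a different shape. The "expected" shape (16, 85) is the stored one; the "got" shape (8, 85) is what `BroadcastPositionBiases(shape=self.shape)` requests now. So the `shape` used at sampling time appears to differ from the one the parameters were created with, along axis 1. The pure-Python sketch below mimics that check. It is hypothetical and not taken from viper_rl. It assumes one bias table `d_{i}` per axis with shape `(shape[i], chunk_i)`, where the embedding dimension is split across the axes. Under that assumption, an `embed_dim` of 256 would give chunks of 86, 85 and 85, matching the 85 in the error. Both grid shapes are made up for illustration.

```python
# Hypothetical sketch, NOT viper_rl code: assumes a per-axis position-bias
# module creates one table "d_{i}" of shape (shape[i], chunk_i) per axis.


def chunk_sizes(embed_dim, n_dim):
    """Split embed_dim into n_dim near-equal chunks; earlier chunks absorb the remainder."""
    return [embed_dim // n_dim + (i < embed_dim % n_dim) for i in range(n_dim)]


def expected_param_shapes(shape, embed_dim):
    """Parameter shapes such a module would request for a given latent grid shape."""
    sizes = chunk_sizes(embed_dim, len(shape))
    return {f"d_{i}": (shape[i], sizes[i]) for i in range(len(shape))}


def find_mismatches(stored, requested):
    """Map each parameter name to (stored_shape, requested_shape) where the two differ."""
    return {
        name: (stored[name], requested[name])
        for name in stored
        if name in requested and stored[name] != requested[name]
    }


if __name__ == "__main__":
    # Hypothetical configs: params created for a (16, 16, 16) grid,
    # then the model is applied with a (16, 8, 8) grid.
    stored = expected_param_shapes((16, 16, 16), embed_dim=256)
    requested = expected_param_shapes((16, 8, 8), embed_dim=256)
    for name, (old, new) in find_mismatches(stored, requested).items():
        print(f"{name}: stored {old}, requested {new}")
```

If the assumption holds, Flax raises at the first mismatching parameter it creates. That would explain why only "d_1" shows up in the error even if "d_2" also differs.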