Is there a dimensionality issue here? Running train_videogpt.py fails with the following error:
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/gpfs/data/oermannlab/users/qp2040/viper_rl/scripts/train_videogpt.py", line 272, in <module>
    main()
  File "/gpfs/data/oermannlab/users/qp2040/viper_rl/scripts/train_videogpt.py", line 79, in main
    visualize(sampler, ae, iteration, state, test_loader)
  File "/gpfs/data/oermannlab/users/qp2040/viper_rl/scripts/train_videogpt.py", line 225, in visualize
    samples = sampler(variables, batch).copy()
  File "/gpfs/data/oermannlab/users/qp2040/viper_rl/viper_rl/videogpt/sampler.py", line 90, in __call__
    , cache = self._model_step(
  File "/gpfs/data/oermannlab/users/qp2040/viper_rl/viper_rl/videogpt/sampler.py", line 38, in _fn
    logits, cache = self.model.apply(
  File "/gpfs/data/oermannlab/users/qp2040/viper_rl/viper_rl/videogpt/models/videogpt.py", line 38, in __call__
    return self.model(
  File "/gpfs/data/oermannlab/users/qp2040/viper_rl/viper_rl/videogpt/models/transformer.py", line 31, in __call__
    position_bias = BroadcastPositionBiases(shape=self.shape)(x)
  File "/gpfs/data/oermannlab/users/qp2040/viper_rl/viper_rl/videogpt/models/transformer.py", line 188, in __call__
    embs = [
  File "/gpfs/data/oermannlab/users/qp2040/viper_rl/viper_rl/videogpt/models/transformer.py", line 189, in <listcomp>
    self.param(f'd_{i}', nn.initializers.normal(stddev=0.02),
flax.errors.ScopeParamShapeError: Initializer expected to generate shape (16, 85) but got shape (8, 85) instead for parameter "d_1" in "/model/BroadcastPositionBiases_0". (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ScopeParamShapeError)