atarilover123 opened this issue 2 years ago
I changed
jnp.broadcast_to([last_sample_z], z_vals[..., :1].shape)
to
jnp.broadcast_to(last_sample_z, z_vals[..., :1].shape)
in line 107 of model_utils.py, and that fixed it for me.
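A minimal sketch of why the change helps. It assumes `last_sample_z` is a scalar "far" distance (1e10 here) and `z_vals` holds sorted per-ray sample depths with shape `(num_rays, num_samples)`. Both arrays, and the `dists` computation around them, are hypothetical stand-ins and not the actual contents of model_utils.py. Recent JAX versions reject a plain Python list as the first argument of `jnp.broadcast_to`, which is the `TypeError` reported in this thread. Passing the scalar directly avoids it, and so does wrapping it in `jnp.array`:

```python
import jax.numpy as jnp

# Hypothetical stand-ins: 4 rays, 8 evenly spaced sample depths each.
z_vals = jnp.tile(jnp.linspace(2.0, 6.0, 8), (4, 1))  # shape (4, 8)
last_sample_z = 1e10  # assumed scalar "far" distance

# Fix from this thread: pass the scalar itself, not [last_sample_z].
# A Python list here raises TypeError in recent JAX versions.
last_dist = jnp.broadcast_to(last_sample_z, z_vals[..., :1].shape)  # (4, 1)

# Alternative: convert the one-element list to a JAX array first.
last_dist_alt = jnp.broadcast_to(jnp.array([last_sample_z]), z_vals[..., :1].shape)

# Distances between adjacent samples, padded with the far distance so
# the result has the same number of samples as z_vals.
dists = jnp.concatenate([z_vals[..., 1:] - z_vals[..., :-1], last_dist], axis=-1)
```

Either form produces the same `(num_rays, 1)` column. The bare scalar is the smaller change to the original line.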
Hi blackz5, it does not work for me. Did you try it on Colab?
No, I ran it locally and changed the file as described above.
I'm getting this error in the "initialize model" cell:
UnfilteredStackTrace: TypeError: broadcast_to requires ndarray or scalar arguments, got <class 'list'> at position 0.