The unbatched PixelCNN example fails with

`Exception: Can't lift Traced<ShapedArray(float32[32,32,1]):JaxprTrace(level=6/0)> to JaxprTrace(level=5/0)`

while initializing parameters in `down_shifted_conv` during `trace_to_jaxpr`. An older version of the network was working on an unbatched input in `jaxnet==0.1.4` and `jax==0.1.41`, so this is probably a bug in the tracing logic of `jaxnet.core`.