juliuskunze / jaxnet

Concise deep learning for JAX
Apache License 2.0
184 stars, 14 forks

Repair unbatched PixelCNN example #5

Closed: juliuskunze closed this 4 years ago

juliuskunze commented 5 years ago

The unbatched PixelCNN example fails with

Exception: Can't lift Traced<ShapedArray(float32[32,32,1]):JaxprTrace(level=6/0)> to JaxprTrace(level=5/0)

while initializing parameters in down_shifted_conv during trace_to_jaxpr. An older version of the network worked on unbatched input with jaxnet==0.1.4 and jax==0.1.41, so this is probably a bug in the tracing logic of jaxnet.core.
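For context on the layer named in the error, here is a minimal sketch in plain JAX (not jaxnet's API) of a PixelCNN-style down-shifted convolution applied to a single unbatched float32[32,32,1] image, then lifted to a batch with jax.vmap. The padding scheme (kh - 1 rows on top, (kw - 1) // 2 columns on each side, odd kw) and the kernel layout (kh, kw, in_channels, out_channels) are assumptions for illustration. The sketch does not reproduce the jaxnet tracing bug or its fix.

```python
import jax
import jax.numpy as jnp
from jax import lax


def down_shifted_conv(image, kernel):
    """Down-shifted 2D convolution on one unbatched HWC image (illustrative sketch).

    Assumes kernel shape (kh, kw, in_ch, out_ch) with odd kw. Padding kh - 1 rows
    on top and (kw - 1) // 2 columns on each side keeps the output at the input's
    height and width, and each output row only sees its own row and rows above it.
    """
    kh, kw = kernel.shape[:2]
    side = (kw - 1) // 2
    padded = jnp.pad(image, ((kh - 1, 0), (side, side), (0, 0)))
    # lax.conv_general_dilated needs a leading batch axis: add it, then strip it.
    out = lax.conv_general_dilated(
        padded[None], kernel,
        window_strides=(1, 1),
        padding="VALID",
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
    )
    return out[0]


key_image, key_kernel = jax.random.split(jax.random.PRNGKey(0))
image = jax.random.normal(key_image, (32, 32, 1))          # unbatched input
kernel = 0.1 * jax.random.normal(key_kernel, (2, 3, 1, 4))  # assumed 2x3 filter, 4 outputs

out = down_shifted_conv(image, kernel)
print(out.shape)

# Batching: map the per-example function over a leading axis, sharing the kernel.
batched_conv = jax.vmap(down_shifted_conv, in_axes=(0, None))
batched_out = batched_conv(jnp.stack([image, 2.0 * image]), kernel)
print(batched_out.shape)
```

Writing the layer for a single example and batching it with jax.vmap keeps the network code free of batch-dimension bookkeeping, which is the style an unbatched example relies on.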