juliuskunze / jaxnet

Concise deep learning for JAX
Apache License 2.0

Fix PixelCNN's weight norm layers #17

Closed juliuskunze closed 5 years ago

juliuskunze commented 5 years ago

The weight norm implementations `dense` and `conv_or_conv_transpose` from the PixelCNN example need custom batching rules and initialization.
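For context, here is a minimal sketch of a weight-normalized dense layer in plain JAX. It assumes that "initialization" refers to the data-dependent initialization commonly paired with weight normalization, where the scale and bias are chosen from statistics of a whole batch. The issue does not spell this out, so treat it as one possible reading. The function names `init_weightnorm_dense` and `apply_weightnorm_dense` and the parameter layout are hypothetical. They are not jaxnet's API and not the code from the pixelcnn example.

```python
import jax
import jax.numpy as jnp


def init_weightnorm_dense(key, x_batch, out_dim, init_scale=1.0):
    """Hypothetical data-dependent init (assumption, not jaxnet's code).

    Picks g and b so that the layer's outputs on x_batch have zero mean
    and standard deviation init_scale for every output unit.
    """
    in_dim = x_batch.shape[-1]
    v = 0.05 * jax.random.normal(key, (in_dim, out_dim))
    v_unit = v / jnp.linalg.norm(v, axis=0, keepdims=True)
    t = x_batch @ v_unit                 # pre-activations, (batch, out_dim)
    mean = t.mean(axis=0)                # statistics over the batch axis
    std = t.std(axis=0) + 1e-10
    g = init_scale / std
    b = -mean * g
    return {"v": v, "g": g, "b": b}


def apply_weightnorm_dense(params, x):
    """Weight-normalized dense layer: w = g * v / ||v||, y = x @ w + b."""
    v, g, b = params["v"], params["g"], params["b"]
    w = g * v / jnp.linalg.norm(v, axis=0, keepdims=True)
    return x @ w + b


key_v, key_x = jax.random.split(jax.random.PRNGKey(0))
x = 3.0 * jax.random.normal(key_x, (64, 8)) + 1.0
params = init_weightnorm_dense(key_v, x, out_dim=4)

y = apply_weightnorm_dense(params, x)    # applied to the whole batch
# Applying the layer per example under vmap gives the same result.
y_vmapped = jax.vmap(lambda xi: apply_weightnorm_dense(params, xi))(x)
```

In this sketch, applying the layer works per example under `jax.vmap`, but the initialization reads the mean and standard deviation across the batch axis. That may be why a purely per-example formulation needs special handling, though the issue itself does not say so.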