patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0
2.12k stars 142 forks

ConvTranspose1d bug #884

Open TugdualKerjan opened 1 month ago

TugdualKerjan commented 1 month ago

Hello, I noticed that running the below code:

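import jax
import equinox.nn as nn

grab1, grab2 = jax.random.split(jax.random.PRNGKey(seed=69), 2)
conv = nn.ConvTranspose1d(512, 256, kernel_size=16, stride=8, padding=4, key=grab1)

x = jax.random.normal(key=grab2, shape=(512, 100))
y = conv(x)  # assumed: the layer is applied to x here (the call was not in the original snippet)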

produces an error that seems to stem from the padding, because if I instead run

grab1, grab2 = jax.random.split(jax.random.PRNGKey(seed=69), 2)
conv = nn.ConvTranspose1d(512, 256, kernel_size=16, stride=8, padding=((2, 6),), key=grab1)

x = jax.random.normal(key=grab2, shape=(512, 100))
y = conv(x)  # assumed: the layer is applied to x here (the call was not in the original snippet)

it works. I might be misunderstanding how Conv1d works, though! This library is amazing, thank you for the fantastic work 💯
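For reference, here is a minimal sketch of the output-length arithmetic for the two padding settings above. It assumes the usual transposed-convolution formula, (L_in - 1) * stride + kernel_size - pad_left - pad_right, with no dilation and no output padding. The helper name conv_transpose_out_len is hypothetical and not part of Equinox, and whether Equinox follows exactly this convention is an assumption here.

```python
def conv_transpose_out_len(length_in, kernel_size, stride, padding):
    """Expected output length of a 1D transposed convolution (hypothetical helper)."""
    # Accept either a single int (symmetric padding) or a (left, right) pair.
    if isinstance(padding, int):
        pad_left, pad_right = padding, padding
    else:
        pad_left, pad_right = padding
    # Assumed formula: no dilation, no output padding.
    return (length_in - 1) * stride + kernel_size - pad_left - pad_right


# The two configurations from the snippets above, with an input of length 100.
symmetric = conv_transpose_out_len(100, kernel_size=16, stride=8, padding=4)
asymmetric = conv_transpose_out_len(100, kernel_size=16, stride=8, padding=(2, 6))
print(symmetric, asymmetric)
```

Under this assumption both settings trim eight positions in total and give the same output length, so the shape arithmetic alone would not explain why only padding=4 fails.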

patrick-kidger commented 1 month ago

Can you give the traceback and message for the error you obtain, and the versions of JAX and Equinox you are using? With JAX 0.4.34 and Equinox 0.11.8 I am unable to obtain an error with your code.
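One way to collect the requested version information is sketched below. It uses only the standard library's importlib.metadata; the helper name installed_versions and the choice to include jaxlib are assumptions for illustration, not something asked for in the thread.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("jax", "jaxlib", "equinox")):
    """Map each package name to its installed version, or None if it is missing."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            # The package is not installed in the current environment.
            found[name] = None
    return found


print(installed_versions())
```

Pasting this output together with the full traceback should make the report easier to reproduce.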