TugdualKerjan opened 1 month ago
Hello, I noticed that running the code below:
```python
import jax
from equinox import nn

grab1, grab2 = jax.random.split(jax.random.PRNGKey(seed=69), 2)
conv = nn.ConvTranspose1d(512, 256, kernel_size=16, stride=8, padding=4, key=grab1)
x = jax.random.normal(key=grab2, shape=(512, 100))
```
produces an error that seems to stem from the padding, because if I instead run
```python
grab1, grab2 = jax.random.split(jax.random.PRNGKey(seed=69), 2)
conv = nn.ConvTranspose1d(512, 256, kernel_size=16, stride=8, padding=((2, 6),), key=grab1)
x = jax.random.normal(key=grab2, shape=(512, 100))
```
It works. I might have a misunderstanding of how Conv1d works, though! This library is amazing; thank you for the fantastic work 💯
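As a rough sanity check on the padding, here is a minimal sketch of the standard transposed-convolution output-length formula. It assumes, as we understand Equinox's convention, that `padding` describes the padding of the equivalent forward convolution, so `padding=4` means `(4, 4)`. The helper name `conv_transpose_out_len` is hypothetical.

```python
def conv_transpose_out_len(length, kernel_size, stride, padding,
                           dilation=1, output_padding=0):
    """Hypothetical helper: output length of a 1D transposed convolution.

    `padding` is the (lo, hi) padding of the equivalent forward convolution.
    """
    lo, hi = padding
    # The usual formula, generalised to asymmetric (lo, hi) padding.
    return ((length - 1) * stride - (lo + hi)
            + dilation * (kernel_size - 1) + output_padding + 1)


print(conv_transpose_out_len(100, 16, 8, (4, 4)))  # padding=4 in the first snippet
print(conv_transpose_out_len(100, 16, 8, (2, 6)))  # padding=((2, 6),) in the second
```

Both settings remove the same total of 8 samples, so under this assumption they give the same output length (800 for an input of length 100); they differ only in where the output window sits.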
Can you give the traceback and message for the error you obtain, and the versions of JAX and Equinox you are using? With JAX 0.4.34 and Equinox 0.11.8 I am unable to obtain an error with your code.
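One way to check that the symmetric padding is not itself invalid is to reproduce the computation with `jax.lax.conv_general_dilated`, a common way to express a transposed convolution: dilate the input by the stride, then pad each side by `kernel_size - 1` minus the forward padding. This is a sketch under that assumption, not Equinox's source. `conv_transpose_1d` is a hypothetical helper, and the weight is used directly as the kernel (no flip), which only matters when matching a specific forward convolution, not for a learned layer.

```python
import jax
from jax import lax


def conv_transpose_1d(x, w, stride, padding):
    """Hypothetical helper: 1D transposed conv as an input-dilated convolution.

    x: (in_channels, length), unbatched as in the snippets above.
    w: (out_channels, in_channels, kernel_size), used directly as the kernel.
    padding: ((lo, hi),), the padding of the equivalent forward convolution.
    """
    kernel_size = w.shape[-1]
    ((lo, hi),) = padding
    # Dilate the input by `stride`, then pad each side by
    # (kernel_size - 1) minus the forward-convolution padding.
    pad = ((kernel_size - 1 - lo, kernel_size - 1 - hi),)
    y = lax.conv_general_dilated(
        x[None],  # add a batch dimension: (1, in_channels, length)
        w,
        window_strides=(1,),
        padding=pad,
        lhs_dilation=(stride,),
    )
    return y[0]  # drop the batch dimension again


key_w, key_x = jax.random.split(jax.random.PRNGKey(69), 2)
w = jax.random.normal(key_w, (256, 512, 16))
x = jax.random.normal(key_x, (512, 100))

y_sym = conv_transpose_1d(x, w, stride=8, padding=((4, 4),))  # padding=4
y_asym = conv_transpose_1d(x, w, stride=8, padding=((2, 6),))
print(y_sym.shape, y_asym.shape)
```

Under these assumptions both paddings produce an output of shape `(256, 800)`, and `((2, 6),)` simply shifts the output by two samples relative to `padding=4`.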