google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

nnx.Swish, jax.swish,... change the input shape #4214

Open leson207 opened 2 weeks ago

leson207 commented 2 weeks ago

System information

Problem you have encountered:

    print(xBC.shape)          # shape before the activation
    xBC = jax.nn.swish(x)     # or: xBC = nnx.swish(x)
    print(xBC.shape)          # shape after the activation

The output shape is not what I expected.

What you expected to happen:

Output:
    (1, 128, 288)
    (1, 128, 128)
Expected:
    (1, 128, 288)
    (1, 128, 288)
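Swish is applied element by element, so its output always has exactly the shape of its input. Note that the snippet above passes `x`, not `xBC`, to the activation, so the second print reports the shape of `x`. Here is a minimal sketch. It assumes hypothetical random arrays `x` of shape (1, 128, 128) and `xBC` of shape (1, 128, 288), picked to mirror the printed shapes, because the report never gives the real shape of `x`:

```python
import jax

# Hypothetical inputs: shapes chosen to mirror the printed output above.
key_x, key_xbc = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (1, 128, 128))
xBC = jax.random.normal(key_xbc, (1, 128, 288))

print(xBC.shape)             # (1, 128, 288)

# Applying the activation to `x` yields an array shaped like `x`...
out_from_x = jax.nn.swish(x)
print(out_from_x.shape)      # (1, 128, 128)

# ...while applying it to `xBC` keeps the shape of `xBC`.
out_from_xbc = jax.nn.swish(xBC)
print(out_from_xbc.shape)    # (1, 128, 288)
```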

cgarciae commented 2 weeks ago

Hey, currently

assert jax.nn.swish is nnx.swish

so I'm not sure what the issue could be here.
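The comment above asserts that `nnx.swish` is the same object as `jax.nn.swish`. Within JAX, `swish` is in turn an alias of `silu`, so all of these names refer to one function. The small check below uses only `jax`. The `nnx` alias is taken from the comment and is not re-checked here:

```python
import jax
import jax.numpy as jnp

# In jax.nn, `swish` and `silu` are the same function object.
print(jax.nn.swish is jax.nn.silu)  # True

# Therefore both names give identical results on the same input.
x = jnp.array([-2.0, 0.0, 2.0])
print(bool(jnp.array_equal(jax.nn.swish(x), jax.nn.silu(x))))  # True
```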

leson207 commented 2 weeks ago

I know they are the same. What I mean is that when I use these functions, which build on each other or are just different names for the same thing (silu, swish), they change my input shape.
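SiLU (swish) is defined as silu(x) = x * sigmoid(x), computed independently for every element. It involves no reduction or reshape, so it cannot change an array's shape. The sketch below compares `jax.nn.silu` against that formula written out by hand. The helper `silu_by_hand` is hypothetical and used only for illustration:

```python
import jax
import jax.numpy as jnp

def silu_by_hand(x):
    # Hypothetical helper: elementwise x * sigmoid(x); no reduction or reshape.
    return x * jax.nn.sigmoid(x)

# Deterministic test input with the (1, 128, 288) shape from the report.
x = jnp.linspace(-4.0, 4.0, 1 * 128 * 288).reshape(1, 128, 288)
ref = silu_by_hand(x)
out = jax.nn.silu(x)

print(out.shape)                                # (1, 128, 288)
print(bool(jnp.allclose(out, ref, atol=1e-6)))  # True
```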