leson207 opened 2 weeks ago
Hey, currently
assert jax.nn.swish is nnx.swish
so I'm not sure what the issue could be here.
I know they are the same function. What I mean is that when I use these functions (whether they build on each other or are just different names for the same thing, like silu and swish), they change the shape of my input.
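Here is a minimal check of the aliasing and of how shapes behave. It assumes a recent JAX release in which `jax.nn.swish` is defined as an alias of `jax.nn.silu`. `nnx.swish` is not imported because the assertion above already ties it to `jax.nn.swish`. The input shape `(1, 128, 288)` is borrowed from the report, and the array values are arbitrary.

```python
import jax
import jax.numpy as jnp

# swish is an alias of silu, so both names refer to the same function object.
assert jax.nn.swish is jax.nn.silu

# silu/swish is elementwise: silu(x) = x * sigmoid(x).
x = jnp.ones((1, 128, 288))
y = jax.nn.swish(x)

# An elementwise activation returns an array with the same shape as its input.
print(x.shape, y.shape)
assert jnp.allclose(y, x * jax.nn.sigmoid(x))
```

Because the function is applied element by element, it cannot change the shape of its argument. A change in shape therefore has to come from which array is passed in.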
System information
Problem you have encountered:
```python
print(xBC.shape)
xBC = jax.nn.swish(x)  # or: xBC = nnx.swish(x)
print(xBC.shape)
```
The output shape is not what I expected.
What you expected to happen:
Output: (1, 128, 288), then (1, 128, 128)
Expected: (1, 128, 288), then (1, 128, 288)
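One possible explanation: the snippet in the report applies swish to `x`, not to `xBC`, and then stores the result in `xBC`. The report does not give the shape of `x`. The sketch below is only an illustration that assumes `x` is a separate array with shape `(1, 128, 128)`. Both arrays are hypothetical zero-filled stand-ins.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins; shapes chosen to mirror the printed output above.
x = jnp.zeros((1, 128, 128))
xBC = jnp.zeros((1, 128, 288))

print(xBC.shape)            # (1, 128, 288)

# As in the report: swish is applied to `x`, so the result has the shape of `x`.
out_from_x = jax.nn.swish(x)
print(out_from_x.shape)     # (1, 128, 128)

# Applying swish to `xBC` itself preserves its shape.
out_from_xBC = jax.nn.swish(xBC)
print(out_from_xBC.shape)   # (1, 128, 288)
```

If this assumption holds, changing the call to `jax.nn.swish(xBC)` (or `nnx.swish(xBC)`) should give the expected `(1, 128, 288)`.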