PikaPei closed this issue 8 months ago
Thanks for the report. It can be easily corrected by changing the line
out_fr = jnp.mean(outs, axis=0)
to
out_fr = bm.mean(outs, axis=0)
Thanks for solving this! I'll make a PR to correct similar errors in this file.
Came from #9. Thanks for solving that problem! There's another error in the same example: