Closed vadmbertr closed 5 months ago
replace lax.while_loop with lax.scan as the latter is rev-mode differentiable
Replace lax.while_loop with lax.scan, as the latter is reverse-mode differentiable.
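For illustration only, here is a minimal sketch of the difference this change relies on. It is not the code touched by this PR: `step`, `run_while`, `run_scan`, and `N_STEPS` are hypothetical names, and the update rule is a toy. It assumes the loop has a trip count known ahead of time, which `lax.scan` requires.

```python
import jax
import jax.numpy as jnp
from jax import lax

N_STEPS = 10  # hypothetical fixed number of iterations


def step(x, a):
    # Toy update rule, standing in for whatever the real loop body does.
    return x + 0.1 * a * x


def run_while(a, x0):
    # Loop written with lax.while_loop: the trip count is decided by cond.
    def cond(carry):
        i, _ = carry
        return i < N_STEPS

    def body(carry):
        i, x = carry
        return i + 1, step(x, a)

    _, x = lax.while_loop(cond, body, (0, x0))
    return x


def run_scan(a, x0):
    # Same computation with lax.scan: the length is fixed up front,
    # so intermediate carries can be stored for the backward pass.
    def body(x, _):
        return step(x, a), None

    x, _ = lax.scan(body, x0, xs=None, length=N_STEPS)
    return x


a = jnp.float32(0.5)
x0 = jnp.float32(1.0)

# Reverse-mode gradient with respect to `a` works through lax.scan.
grad_scan = jax.grad(run_scan)(a, x0)

# Forward-mode still works through lax.while_loop.
jac_while = jax.jacfwd(run_while)(a, x0)

# Reverse-mode through lax.while_loop is expected to raise.
try:
    jax.grad(run_while)(a, x0)
    while_rev_ok = True
except Exception:
    while_rev_ok = False
```

The trade-off is that `lax.scan` always runs the full `length` iterations, so a loop with a data-dependent stopping condition needs an upper bound on its trip count (and, if required, masking of the steps past the stopping point).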