Open nissy-dev opened 4 years ago
This is another example
import jax.numpy as jnp
import jax.lax as lax


def main():
    iterator = iter(range(10))

    def iterator_fori_loop():
        def body_fun(i, cnt):
            cnt += next(iterator)
            return cnt
        return lax.fori_loop(0, 10, body_fun, 0)

    # cnt was expected to be 45, but cnt was 0
    cnt = iterator_fori_loop()
    print(cnt)

    array = jnp.arange(10)

    def array_fori_loop():
        def body_fun(i, cnt):
            cnt += array[i]
            return cnt
        return lax.fori_loop(0, 10, body_fun, 0)

    # cnt got 45, this result was expected
    cnt = array_fori_loop()
    print(cnt)


main()
Thanks for the question!
Updating an iterator is a side effect, so it doesn't work with JAX's functional programming model: the body function passed to fori_loop must not have side effects.
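To make the point above concrete, here is a minimal sketch (not taken from the JAX docs) of why a side effect inside the body misbehaves: the body is traced rather than run as ordinary Python on every iteration, so a Python-level side effect happens only while tracing. The list `trace_calls` is a hypothetical name introduced just for this illustration.

```python
from jax import lax

# Hypothetical helper list used only to observe Python-level calls.
trace_calls = []


def body_fun(i, cnt):
    # Side effect: this append runs when JAX traces the body,
    # not once per loop iteration.
    trace_calls.append("called")
    return cnt + 1


result = lax.fori_loop(0, 10, body_fun, 0)

print(int(result))       # the loop itself still runs 10 iterations
print(len(trace_calls))  # fewer than 10: the Python body is not re-run each step
```

The same mechanism explains the iterator example: next(iterator) is evaluated during tracing, and the value it returned then is baked into the loop as a constant.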
Do you have a suggestion of where we could improve the documentation to make this clearer?
This reminds me of the idea in #1141.
@mattjj Thanks for a quick response!
> the body function passed to fori_loop must not have side effects
I understand! This is natural because JAX follows a functional programming style.
> Do you have a suggestion of where we could improve the documentation to make this clearer?
I think the JAX Frequently Asked Questions (FAQ) or the Structured control flow primitives section would be a good place.
The sample code in Structured control flow primitives was very useful when I tried to use fori_loop or scan.
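For readers comparing the two primitives mentioned here, the following is a rough sketch of the same accumulation written with scan, assuming the data has already been materialized as an array. The names `xs` and `step` are illustrative and not from the thread.

```python
import jax.numpy as jnp
from jax import lax

# Data is materialized as an array before the loop; no Python iterator is involved.
xs = jnp.arange(10)


def step(carry, x):
    # Pure function: returns the new carry and a per-step output.
    new_carry = carry + x
    return new_carry, new_carry


# Start the carry with the same dtype as xs so the carry type stays fixed.
init = jnp.zeros((), dtype=xs.dtype)
total, running = lax.scan(step, init, xs)

print(int(total))        # 45
print(running.tolist())  # running sums of 0..9
```

Here scan feeds one element of xs to the body per step, so there is no need to index by the loop counter as in fori_loop.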
Hi @nissy-dev
PR #3632, which documents the use of iterators with fori_loop, has been merged, and the change is reflected in the updated JAX documentation.
Could you please verify and confirm?
Thank you.
Hi! I have one question about using an iterator with fori_loop.
I wrote the following code, but the iterator in loop_fun doesn't work well: the batch value is always the same in every loop iteration. What should I do to make this iterator work correctly? If possible, I would like to avoid converting train_iterator to a jax.numpy.array.
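One possible way to keep the iterator as a plain Python object is to drive the loop in Python and jit-compile only the per-batch step. The sketch below is an assumption about what the setup might look like, since the original code is not shown: the bodies of `train_iterator` and `loop_fun` are hypothetical stand-ins that reuse the names from the question.

```python
import numpy as np
import jax
import jax.numpy as jnp


def train_iterator():
    # Hypothetical data source: yields three NumPy batches of four values each.
    for start in range(0, 12, 4):
        yield np.arange(start, start + 4, dtype=np.float32)


@jax.jit
def loop_fun(state, batch):
    # Pure per-batch update (stand-in for a real training step):
    # the result depends only on the arguments.
    return state + jnp.sum(batch)


state = jnp.zeros((), dtype=jnp.float32)
# The Python for loop advances the iterator; only loop_fun is compiled.
for batch in train_iterator():
    state = loop_fun(state, batch)

print(float(state))  # 66.0, the sum of 0..11
```

The trade-off is that each step is dispatched from Python instead of being fused into a single compiled loop, which is usually acceptable when each batch does a meaningful amount of work.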