jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Using iterator with fori_loop #3567

Open nissy-dev opened 4 years ago

nissy-dev commented 4 years ago

Hi! I have a question about using an iterator with fori_loop.

I wrote the following code, but the iterator in loop_fun doesn't work: the batch value is the same on every iteration of the loop.

What should I do to make this iterator work correctly? If possible, I'd like to avoid converting train_iterator to a jax.numpy array.

```python
import jax
from jax import lax

# loss, optimizer, optix, opt_state initialisation, train_iterator and
# train_num_batches are assumed to be defined elsewhere in the script.

@jax.jit
def run_epoch(params_and_states):
    """Update params and states for each epoch."""
    def loop_fun(idx, params_and_states):
        """Update params, state, and optimizer state for one batch."""
        params, state, opt_state = params_and_states

        # This iterator doesn't work: batch has the same value on every iteration.
        batch = next(train_iterator)

        (_, new_state), grads = jax.value_and_grad(loss, has_aux=True)(params, state, batch)
        updates, new_opt_state = optimizer.update(grads, opt_state)
        new_params = optix.apply_updates(params, updates)
        return (new_params, new_state, new_opt_state)

    return lax.fori_loop(0, train_num_batches, loop_fun, params_and_states)
```

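One common workaround, sketched here under clearly hypothetical assumptions, is to keep the iterator in plain Python and `jax.jit` only the per-batch update, so `next(...)` runs eagerly on every step instead of being evaluated during tracing. The linear model, `loss`, `make_batches`, `LEARNING_RATE`, and the plain SGD update (used in place of `optix` to keep the example self-contained) are illustrative stand-ins, not the code from this issue.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # hypothetical hyperparameter


def loss(params, state, batch):
    """Hypothetical squared-error loss for a linear model; returns (loss, new_state)."""
    x, y = batch
    pred = x @ params["w"] + params["b"]
    new_state = state + 1  # hypothetical auxiliary state: a step counter
    return jnp.mean((pred - y) ** 2), new_state


@jax.jit
def train_step(params, state, batch):
    """One jitted update; the batch is an explicit argument, not pulled from an iterator."""
    (loss_value, new_state), grads = jax.value_and_grad(loss, has_aux=True)(
        params, state, batch
    )
    # Plain SGD stands in for an optimizer library here.
    new_params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    return new_params, new_state, loss_value


def make_batches(key, num_batches, batch_size=32):
    """Hypothetical data iterator yielding (x, y) pairs with y = 3x + 1."""
    for _ in range(num_batches):
        key, subkey = jax.random.split(key)
        x = jax.random.normal(subkey, (batch_size, 1))
        yield x, 3.0 * x[:, 0] + 1.0


def run_epoch(params, state, train_iterator):
    """Plain Python loop: next() runs eagerly, so every step sees a fresh batch."""
    loss_value = None
    for batch in train_iterator:
        params, state, loss_value = train_step(params, state, batch)
    return params, state, loss_value


params = {"w": jnp.zeros((1,)), "b": jnp.zeros(())}
params, state, last_loss = run_epoch(params, 0, make_batches(jax.random.PRNGKey(0), 100))
```

The trade-off is one dispatch from Python per batch. If all batches for an epoch fit in memory, stacking them into an array and indexing it inside `fori_loop` (or feeding it to `lax.scan`) keeps the whole loop inside one compiled computation.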
nissy-dev commented 4 years ago

Here is another example:

```python
import jax.numpy as jnp
import jax.lax as lax

def main():
    iterator = iter(range(10))

    def iterator_fori_loop():

        def body_fun(i, cnt):
            cnt += next(iterator)
            return cnt

        return lax.fori_loop(0, 10, body_fun, 0)

    # Expected cnt to be 45, but it was 0
    cnt = iterator_fori_loop()
    print(cnt)

    array = jnp.arange(10)
    def array_fori_loop():

        def body_fun(i, cnt):
            cnt += array[i]
            return cnt

        return lax.fori_loop(0, 10, body_fun, 0)

    # cnt was 45, as expected
    cnt = array_fori_loop()
    print(cnt)

main()
```
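When the values really do come from a Python iterator, one option (assuming they fit in memory) is to drain the iterator eagerly into an array before any tracing happens, then let `lax.scan` hand one element to each iteration. A minimal sketch for the toy sum above; `scan_sum` is a hypothetical name:

```python
import jax.numpy as jnp
import jax.lax as lax


def scan_sum(iterator):
    # Consume the iterator eagerly, outside of any traced code.
    xs = jnp.asarray(list(iterator))

    def body_fun(cnt, x):
        # scan passes a different element of xs as `x` on each step.
        return cnt + x, None

    # Initial carry has the same dtype as xs so the carry type stays fixed.
    cnt, _ = lax.scan(body_fun, jnp.zeros((), dtype=xs.dtype), xs)
    return cnt


print(scan_sum(iter(range(10))))  # prints 45
```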
mattjj commented 4 years ago

Thanks for the question!

Updating an iterator is a side effect, and so it doesn't work with JAX's functional programming model: the body function passed to fori_loop must not have side effects.
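A quick way to see this: JAX executes the Python body of a transformed function only while tracing, and the compiled result contains no Python. The sketch below (the names `calls` and `add_one` are illustrative) logs every execution of a jitted function's Python body. `fori_loop` behaves in the same spirit: its body is traced rather than run once per iteration, so `next(train_iterator)` is evaluated at trace time and its value is baked into the loop as a constant.

```python
import jax
import jax.numpy as jnp

calls = []  # Python-side log; appending to it is a side effect


@jax.jit
def add_one(x):
    calls.append("traced")  # runs only while JAX traces the function
    return x + 1


print(add_one(jnp.array(1.0)))  # 2.0
print(add_one(jnp.array(2.0)))  # 3.0; same shape/dtype, so the cached compilation is reused
print(len(calls))  # 1: the side effect happened once, not once per call
```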

Do you have a suggestion of where we could improve the documentation to make this clearer?

mattjj commented 4 years ago

This reminds me of the idea in #1141.

nissy-dev commented 4 years ago

@mattjj Thanks for a quick response!

> the body function passed to fori_loop must not have side effects

I understand! That makes sense, since JAX follows a functional programming style.

> Do you have a suggestion of where we could improve the documentation to make this clearer?

I think the JAX Frequently Asked Questions (FAQ) or the Structured control flow primitives section would be good places for it.

The sample code in Structured control flow primitives was very useful when I tried to use fori_loop and scan.

rajasekharporeddy commented 7 months ago

Hi @nissy-dev

PR #3632, which documents the use of iterators with fori_loop, has been merged, and the change is reflected in the updated JAX documentation.

Could you please verify and confirm?

Thank you.