jc-bao / policy-adaptation-survey

This repository compares prevailing adaptive control methods from both the control and learning communities.
Apache License 2.0

Jax logging issue. #35

Open jc-bao opened 1 year ago

jc-bao commented 1 year ago

I want to log metrics from the following PPO training script, written in JAX, in real time. Currently I can only log all the related metrics once everything is done and plot them with plt.plot(out["metrics"]["returned_episode_returns"].mean(-1).reshape(-1)). The training is launched as follows:

train_jit = jax.jit(make_train(config))
out = jax.block_until_ready(train_jit(rng))
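
For reference, the full end-of-run workflow currently looks roughly like this (a sketch; config, rng, and make_train come from the setup described here, and the plotting call is the one mentioned above):

import jax
import matplotlib.pyplot as plt

train_jit = jax.jit(make_train(config))
out = jax.block_until_ready(train_jit(rng))

# The learning curve can only be inspected after the entire scan has finished.
plt.plot(out["metrics"]["returned_episode_returns"].mean(-1).reshape(-1))
plt.show()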

Help me modify the following training code so that it can log its status in real time. Please note that the train function still needs to be jitted. The original training function code is as follows:

def make_train(config):
    config["NUM_UPDATES"] = (
        config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
    )
    config["MINIBATCH_SIZE"] = (
        config["NUM_ENVS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"]
    )
    env = ...

    def train(rng):
        # INIT ENV
        ...

        # INIT NETWORK
        ...

        # TRAIN LOOP
        def _update_step(runner_state, unused):
            # COLLECT TRAJECTORIES
            def _env_step(runner_state, unused):
                ...
                return runner_state, transition

            runner_state, traj_batch = jax.lax.scan(
                _env_step, runner_state, None, config["NUM_STEPS"]
            )

            # CALCULATE ADVANTAGE
            train_state, env_state, last_obs, rng, env_params = runner_state
            _, last_val = network.apply(train_state.params, last_obs)

            def _calculate_gae(traj_batch, last_val):
                ...
                return advantages, advantages + traj_batch.value

            advantages, targets = _calculate_gae(traj_batch, last_val)

            # UPDATE NETWORK
            def _update_epoch(update_state, unused):
                def _update_minbatch(train_state, batch_info):
                    ...
                    return train_state, total_loss

                ...
                return update_state, total_loss

            ...
            return runner_state, metric

        ...
        runner_state, metric = jax.lax.scan(
            _update_step, runner_state, None, config["NUM_UPDATES"]
        )
        return {"runner_state": runner_state, "metrics": metric}

    return train

In this implementation, the whole training loop is driven by

        runner_state, metric = jax.lax.scan(
            _update_step, runner_state, None, config["NUM_UPDATES"]
        )

which makes it hard to log metrics as training progresses. Please modify the code so that I can log the training status in real time.
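
One approach that keeps train fully jitted is jax.debug.callback, which calls a host-side Python function with concrete values from inside jitted code, including inside jax.lax.scan. Below is a minimal, self-contained sketch of the idea; the toy carry and the "returned_episode_returns" metric only stand in for the real runner state and metrics pytree:

import jax
import jax.numpy as jnp
import numpy as np

def make_train(num_updates):
    def train(x0):
        def _update_step(x, _):
            # Toy stand-in for trajectory collection, GAE, and network updates.
            x = x + 1.0
            metric = {"returned_episode_returns": x}

            def _log_to_host(m):
                # Runs on the host with concrete values, even though the
                # surrounding scan is jitted.
                mean_ret = float(np.asarray(m["returned_episode_returns"]).mean())
                print(f"mean episode return: {mean_ret:.3f}")

            # Fires once per update step inside jax.lax.scan.
            jax.debug.callback(_log_to_host, metric)
            return x, metric

        x, metrics = jax.lax.scan(_update_step, x0, None, num_updates)
        return {"runner_state": x, "metrics": metrics}

    return train

train_jit = jax.jit(make_train(5))
out = jax.block_until_ready(train_jit(jnp.zeros(4)))

In the real _update_step, the same jax.debug.callback(...) line would go right before return runner_state, metric, and the host function could call print, wandb.log, or any other logger. For simple messages, jax.debug.print works the same way.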
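
An alternative, if some host overhead per update is acceptable, is to jit only a chunk of updates and keep the outer loop in ordinary Python, logging between chunks. Again a minimal sketch with toy state; make_train_step and num_inner_steps are names introduced here only for illustration:

import jax
import jax.numpy as jnp

def make_train_step(num_inner_steps):
    def train_step(x):
        def _update_step(x, _):
            # Toy stand-in for one PPO update.
            x = x + 1.0
            metric = {"returned_episode_returns": x}
            return x, metric

        # Run a fixed-size chunk of updates inside jit.
        return jax.lax.scan(_update_step, x, None, num_inner_steps)

    return jax.jit(train_step)

train_step_jit = make_train_step(num_inner_steps=10)
runner_state = jnp.zeros(4)
for chunk in range(5):
    runner_state, metrics = jax.block_until_ready(train_step_jit(runner_state))
    # Ordinary Python logging between jitted chunks.
    mean_ret = float(metrics["returned_episode_returns"].mean())
    print(f"chunk {chunk}: mean episode return {mean_ret:.3f}")

The trade-off is that the single scan over NUM_UPDATES is replaced by a Python loop, so the chunk size controls how often control returns to the host for logging.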